Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

All notable changes to this project will be documented in this file.

## [1.1.0] - 2026-02-05
- Added `softmax_cap` parameter to `pivotal_attention` and `pivotal_attention3` for improved numerical stability.
- Added LRGB example script.

## [1.0.0] - 2026-01-25
- Full release with training and evaluation scripts for Graph Count, BREC, and TSP.
- Added `pivotal_attention3` functional API for 3-Floyd attention.
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ uv pip install -e .

## Changelog (latest)

- Full release with training and evaluation scripts for Graph Count, BREC, and TSP.
- Added `pivotal_attention3` functional API for 3-Floyd attention.
- Added additional configuration options in `PivotalAttentionBlock`.
- Added `softmax_cap` parameter to `pivotal_attention` and `pivotal_attention3` for improved numerical stability.
- Added LRGB example script.


The full changelog is in [CHANGELOG.md](CHANGELOG.md).

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "floydnet"
version = "1.0.0"
version = "1.1.0"
description = "Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs"
readme = "README.md"
requires-python = ">=3.9"
Expand Down
16 changes: 15 additions & 1 deletion src/floydnet/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def pivotal_attention(
dropout: float = 0.0,
scale: Optional[float] = None,
inf: float = 1e9,
softmax_cap: float = -1,
) -> torch.Tensor:
"""Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING".

Expand All @@ -47,6 +48,9 @@ def pivotal_attention(
dropout: Dropout probability applied to attention weights (only effective if > 0).
scale: Optional custom scaling factor. If None, defaults to 1/sqrt(2*D).
inf: Value to use for -infinity in masks.
softmax_cap: If > 0, applies a tanh-based logit cap before softmax.
Note: when using a non-boolean (additive) attn_mask, ensure its magnitude/semantics remain compatible
with capping (e.g., very large negative values used to approximate -inf can interact with logit shaping).

Returns:
Tensor of shape (B, H, L_i, L_k, D)
Expand All @@ -65,6 +69,9 @@ def pivotal_attention(
attn_scores = torch.einsum("bhikd,bhijd->bhikj", q_ik, k_ij) \
+ torch.einsum("bhikd,bhjkd->bhikj", q_ik, k_jk)

if softmax_cap > 0:
attn_scores = softmax_cap * torch.tanh(attn_scores / softmax_cap)

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_scores = attn_scores.masked_fill(attn_mask, -inf)
Expand Down Expand Up @@ -93,6 +100,7 @@ def pivotal_attention3(
dropout: float = 0.0,
scale: Optional[float] = None,
inf: float = 1e9,
softmax_cap: float = -1,
) -> torch.Tensor:
"""3-Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING".

Expand All @@ -111,9 +119,12 @@ def pivotal_attention3(
dropout: Dropout probability applied to attention weights (only effective if > 0).
scale: Optional custom scaling factor. If None, defaults to 1/sqrt(3*D).
inf: Value to use for -infinity in masks.
softmax_cap: If > 0, applies a tanh-based logit cap before softmax.
Note: when using a non-boolean (additive) attn_mask, ensure its magnitude/semantics remain compatible
with capping (e.g., very large negative values used to approximate -inf can interact with logit shaping).

Returns:
Tensor of shape (B, H, L_i, l_j, L_k, D)
Tensor of shape (B, H, L_i, L_j, L_k, D)
"""
assert all([t.dim() == 6 for t in [q_ijk, k_pjk, k_ipk, k_ijp, v_pjk, v_ipk, v_ijp]]), "All inputs must be 6D tensors"
B, H, L_i, L_j, L_k, D = q_ijk.shape
Expand All @@ -130,6 +141,9 @@ def pivotal_attention3(
attn_scores = torch.einsum("bhijkd,bhpjkd->bhijkp", q_ijk, k_pjk) \
+ torch.einsum("bhijkd,bhipkd->bhijkp", q_ijk, k_ipk) \
+ torch.einsum("bhijkd,bhijpd->bhijkp", q_ijk, k_ijp)

if softmax_cap > 0:
attn_scores = softmax_cap * torch.tanh(attn_scores / softmax_cap)

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
Expand Down
Loading