diff --git a/CHANGELOG.md b/CHANGELOG.md
index d4ea80f7..5d948d3d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## Unreleased
+
+### Added
+- LiGR transformer layers from "From Features to Transformers: Redefining Ranking for Scalable Impact" ([#295](https://github.com/MobileTeleSystems/RecTools/pull/295))
+
## [0.16.0] - 27.07.2025
### Added
diff --git a/README.md b/README.md
index b383177b..8c58ccc7 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,15 @@ faster than ever before.
- In [HSTU tutorial](examples/tutorials/transformers_HSTU_tutorial.ipynb) we show that original metrics reported for HSTU on public Movielens datasets may actually be **underestimated**
- Configurable, customizable, callback-friendly, checkpoints-included, logs-out-of-the-box, custom-validation-ready, multi-gpu-compatible! See [Transformers Advanced Training User Guide](examples/tutorials/transformers_advanced_training_guide.ipynb) and [Transformers Customization Guide](examples/tutorials/transformers_customization_guide.ipynb)
+
+## ✨ Highlights: RecTools framework at ACM RecSys'25 ✨
+
+**RecTools implementations are featured at ACM RecSys'25 in ["eSASRec: Enhancing Transformer-based Recommendations in a Modular Fashion"](https://www.arxiv.org/abs/2508.06450):**
+- The paper presents a systematic benchmark of Transformer modifications using RecTools models, with a detailed evaluation of training objectives, Transformer architectures, loss functions, and negative sampling strategies in realistic, production-like settings
+- We introduce a new SOTA baseline, **eSASRec**, which combines SASRec's training objective with LiGR Transformer layers and Sampled Softmax loss, forming a simple yet powerful recipe
+- **eSASRec** shows a 23% boost over SOTA models, such as ActionPiece, on academic benchmarks
+- [LiGR](https://arxiv.org/pdf/2502.03417) Transformer layers used in **eSASRec** are now available in RecTools (see the example below)
+
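+A minimal sketch of the eSASRec recipe described above (hyperparameter values here are illustrative, not the exact settings from the paper):
+
+```python
+from rectools.models import SASRecModel
+from rectools.models.nn.transformers.ligr import LiGRLayers
+
+# eSASRec = SASRec training objective + LiGR transformer layers + sampled softmax loss
+model = SASRecModel(
+    transformer_layers_type=LiGRLayers,
+    loss="sampled_softmax",
+    n_factors=128,  # illustrative sizes
+    n_blocks=2,
+)
+# model.fit(dataset)  # `dataset` is a rectools.dataset.Dataset with interactions
+# recos = model.recommend(users=users, dataset=dataset, k=10, filter_viewed=True)
+```
+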
Please note that we always compare the quality of our implementations to results reported in academic papers. [Public benchmarks for transformer models SASRec and BERT4Rec](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) show that RecTools implementations achieve the highest scores on multiple datasets compared to other published results.
@@ -107,7 +116,7 @@ The table below lists recommender models that are available in RecTools.
| Model | Type | Description (π for user/item features, π for warm inference, βοΈ for cold inference support) | Tutorials & Benchmarks |
|---------------------|----|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|
| HSTU | Neural Network | `rectools.models.HSTUModel` - Sequential model with unidirectional pointwise aggregated attention mechanism, incorporating relative attention bias from positional and temporal information, introduced in ["Actions speak louder than words..."](https://arxiv.org/pdf/2402.17152), combined with "Shifted Sequence" training objective as in original public benchmarks
π | π [HSTU Theory & Practice](examples/tutorials/transformers_HSTU_tutorial.ipynb)
π [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)
π [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb)
π [Top performance on public datasets](examples/tutorials/transformers_HSTU_tutorial.ipynb)
-| SASRec | Neural Network | `rectools.models.SASRecModel` - Transformer-based sequential model with unidirectional attention mechanism and "Shifted Sequence" training objective
π | π [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)
π [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb)
π [Customization guide](examples/tutorials/transformers_customization_guide.ipynb)
π [Top performance on public benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) |
+| SASRec | Neural Network | `rectools.models.SASRecModel` - Transformer-based sequential model with unidirectional attention mechanism and "Shifted Sequence" training objective.
For the eSASRec variant, pass `rectools.models.nn.transformers.ligr.LiGRLayers` as `transformer_layers_type` and `sampled_softmax` as `loss` (see the eSASRec example above)
π | π [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)
π [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb)
π [Customization guide](examples/tutorials/transformers_customization_guide.ipynb)
π [Top performance on public benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) |
| BERT4Rec | Neural Network | `rectools.models.BERT4RecModel` - Transformer-based sequential model with bidirectional attention mechanism and "MLM" (masked item) training objective
π | π [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)
π [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb)
π [Customization guide](examples/tutorials/transformers_customization_guide.ipynb)
π [Top performance on public benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) |
| [implicit](https://github.com/benfred/implicit) ALS Wrapper | Matrix Factorization | `rectools.models.ImplicitALSWrapperModel` - Alternating Least Squares Matrix Factorization algorithm for implicit feedback.
π | π [Theory & Practice](https://rectools.readthedocs.io/en/latest/examples/tutorials/baselines_extended_tutorial.html#Implicit-ALS)
π [50% boost to metrics with user & item features](examples/5_benchmark_iALS_with_features.ipynb) |
| [implicit](https://github.com/benfred/implicit) BPR-MF Wrapper | Matrix Factorization | `rectools.models.ImplicitBPRWrapperModel` - Bayesian Personalized Ranking Matrix Factorization algorithm. | π [Theory & Practice](https://rectools.readthedocs.io/en/latest/examples/tutorials/baselines_extended_tutorial.html#Bayesian-Personalized-Ranking-Matrix-Factorization-(BPR-MF)) |
diff --git a/rectools/models/nn/transformers/ligr.py b/rectools/models/nn/transformers/ligr.py
new file mode 100644
index 00000000..cb18b4fd
--- /dev/null
+++ b/rectools/models/nn/transformers/ligr.py
@@ -0,0 +1,177 @@
+import typing as tp
+
+import torch
+from torch import nn
+
+from rectools.models.nn.transformers.net_blocks import TransformerLayersBase
+
+from .net_blocks import init_feed_forward
+
+
+class LiGRLayer(nn.Module):
+ """
+ Transformer Layer as described in "From Features to Transformers:
+ Redefining Ranking for Scalable Impact" https://arxiv.org/pdf/2502.03417
+
+ Parameters
+ ----------
+ n_factors: int
+ Latent embeddings size.
+ n_heads: int
+ Number of attention heads.
+ dropout_rate: float
+ Probability of a hidden unit to be zeroed.
+ ff_factors_multiplier: int, default 4
+ Feed-forward layers latent embedding size multiplier.
+    bias_in_ff: bool, default ``False``
+        If ``True``, add bias to linear layers of the feed-forward network.
+ ff_activation: {"swiglu", "relu", "gelu"}, default "swiglu"
+ Activation function to use.
+ """
+
+ def __init__(
+ self,
+ n_factors: int,
+ n_heads: int,
+ dropout_rate: float,
+ ff_factors_multiplier: int = 4,
+ bias_in_ff: bool = False,
+ ff_activation: str = "swiglu",
+ ):
+ super().__init__()
+ self.multi_head_attn = nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True)
+ self.layer_norm_1 = nn.LayerNorm(n_factors)
+ self.dropout_1 = nn.Dropout(dropout_rate)
+ self.layer_norm_2 = nn.LayerNorm(n_factors)
+ self.feed_forward = init_feed_forward(n_factors, ff_factors_multiplier, dropout_rate, ff_activation, bias_in_ff)
+ self.dropout_2 = nn.Dropout(dropout_rate)
+
+ self.gating_linear_1 = nn.Linear(n_factors, n_factors)
+ self.gating_linear_2 = nn.Linear(n_factors, n_factors)
+
+ def forward(
+ self,
+ seqs: torch.Tensor,
+ attn_mask: tp.Optional[torch.Tensor],
+ key_padding_mask: tp.Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ """
+ Forward pass through transformer block.
+
+ Parameters
+ ----------
+ seqs: torch.Tensor
+ User sequences of item embeddings.
+ attn_mask: torch.Tensor, optional
+ Optional mask to use in forward pass of multi-head attention as `attn_mask`.
+ key_padding_mask: torch.Tensor, optional
+ Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.
+
+
+ Returns
+ -------
+ torch.Tensor
+ User sequences passed through transformer layers.
+ """
+ mha_input = self.layer_norm_1(seqs)
+ mha_output, _ = self.multi_head_attn(
+ mha_input,
+ mha_input,
+ mha_input,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )
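+        # LiGR gated residual: a sigmoid gate computed from the block input (pre-norm `seqs`)
+        # scales the dropped-out attention output before it is added back to the residual stream.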
+ gated_skip = torch.nn.functional.sigmoid(self.gating_linear_1(seqs))
+ seqs = seqs + torch.mul(gated_skip, self.dropout_1(mha_output))
+
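+        # Second gated residual: the same sigmoid gating pattern applied around the feed-forward block.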
+ ff_input = self.layer_norm_2(seqs)
+ ff_output = self.feed_forward(ff_input)
+ gated_skip = torch.nn.functional.sigmoid(self.gating_linear_2(seqs))
+ seqs = seqs + torch.mul(gated_skip, self.dropout_2(ff_output))
+ return seqs
+
+
+class LiGRLayers(TransformerLayersBase):
+ """
+ LiGR Transformer blocks.
+
+ Parameters
+ ----------
+ n_blocks: int
+ Number of transformer blocks.
+ n_factors: int
+ Latent embeddings size.
+ n_heads: int
+ Number of attention heads.
+ dropout_rate: float
+ Probability of a hidden unit to be zeroed.
+ ff_factors_multiplier: int, default 4
+ Feed-forward layers latent embedding size multiplier. Pass in ``transformer_layers_kwargs`` to override.
+ ff_activation: {"swiglu", "relu", "gelu"}, default "swiglu"
+ Activation function to use. Pass in ``transformer_layers_kwargs`` to override.
+    bias_in_ff: bool, default ``False``
+        If ``True``, add bias to linear layers of the feed-forward network. Pass in ``transformer_layers_kwargs`` to override.
+ """
+
+ def __init__(
+ self,
+ n_blocks: int,
+ n_factors: int,
+ n_heads: int,
+ dropout_rate: float,
+ ff_factors_multiplier: int = 4,
+ ff_activation: str = "swiglu",
+ bias_in_ff: bool = False,
+ ):
+ super().__init__()
+ self.n_blocks = n_blocks
+ self.n_factors = n_factors
+ self.n_heads = n_heads
+ self.dropout_rate = dropout_rate
+ self.ff_factors_multiplier = ff_factors_multiplier
+ self.ff_activation = ff_activation
+ self.bias_in_ff = bias_in_ff
+ self.transformer_blocks = nn.ModuleList([self._init_transformer_block() for _ in range(self.n_blocks)])
+
+ def _init_transformer_block(self) -> nn.Module:
+ return LiGRLayer(
+ self.n_factors,
+ self.n_heads,
+ self.dropout_rate,
+ self.ff_factors_multiplier,
+ bias_in_ff=self.bias_in_ff,
+ ff_activation=self.ff_activation,
+ )
+
+ def forward(
+ self,
+ seqs: torch.Tensor,
+ timeline_mask: torch.Tensor,
+ attn_mask: tp.Optional[torch.Tensor],
+ key_padding_mask: tp.Optional[torch.Tensor],
+ **kwargs: tp.Any,
+ ) -> torch.Tensor:
+ """
+ Forward pass through transformer blocks.
+
+ Parameters
+ ----------
+ seqs: torch.Tensor
+ User sequences of item embeddings.
+ timeline_mask: torch.Tensor
+ Mask indicating padding elements.
+ attn_mask: torch.Tensor, optional
+ Optional mask to use in forward pass of multi-head attention as `attn_mask`.
+ key_padding_mask: torch.Tensor, optional
+ Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.
+
+
+ Returns
+ -------
+ torch.Tensor
+ User sequences passed through transformer layers.
+ """
+ for block_idx in range(self.n_blocks):
+ seqs = self.transformer_blocks[block_idx](seqs, attn_mask, key_padding_mask)
+ return seqs
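A minimal sketch of driving the new layers directly, e.g. for review or debugging (tensor shapes, the causal mask and the empty padding mask below are illustrative; inside RecTools these inputs are produced by the transformer backbone):

```python
import torch

from rectools.models.nn.transformers.ligr import LiGRLayers

layers = LiGRLayers(n_blocks=2, n_factors=64, n_heads=4, dropout_rate=0.1)
seqs = torch.randn(8, 20, 64)  # (batch, session_len, n_factors)
timeline_mask = torch.ones(8, 20, 1)  # padding indicator, not used by LiGRLayers itself
attn_mask = torch.triu(torch.ones(20, 20, dtype=torch.bool), diagonal=1)  # causal attention mask
key_padding_mask = torch.zeros(8, 20, dtype=torch.bool)  # no padded positions in this toy batch
out = layers(seqs, timeline_mask, attn_mask, key_padding_mask)
assert out.shape == seqs.shape  # (8, 20, 64)
```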
diff --git a/rectools/models/nn/transformers/net_blocks.py b/rectools/models/nn/transformers/net_blocks.py
index 947b278b..b3ed5697 100644
--- a/rectools/models/nn/transformers/net_blocks.py
+++ b/rectools/models/nn/transformers/net_blocks.py
@@ -33,14 +33,18 @@ class PointWiseFeedForward(nn.Module):
Probability of a hidden unit to be zeroed.
activation: torch.nn.Module
Activation function module.
+ bias: bool, default ``True``
+ If ``True``, add bias to linear layers.
"""
- def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, activation: torch.nn.Module) -> None:
+ def __init__(
+ self, n_factors: int, n_factors_ff: int, dropout_rate: float, activation: torch.nn.Module, bias: bool = True
+ ) -> None:
super().__init__()
- self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff)
+ self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff, bias)
self.ff_dropout_1 = torch.nn.Dropout(dropout_rate)
self.ff_activation = activation
- self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors)
+ self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors, bias)
def forward(self, seqs: torch.Tensor) -> torch.Tensor:
"""
@@ -61,6 +65,92 @@ def forward(self, seqs: torch.Tensor) -> torch.Tensor:
return fin
+class SwigluFeedForward(nn.Module):
+ """
+ Feed-Forward network to introduce nonlinearity into the transformer model.
+    This implementation is based on the SwiGLU feed-forward used in FuXi (https://arxiv.org/pdf/2502.03036),
+    LLaMA, and LiGR (https://arxiv.org/pdf/2502.03417).
+
+ Parameters
+ ----------
+ n_factors : int
+ Latent embeddings size.
+ n_factors_ff : int
+ How many hidden units to use in the network.
+ dropout_rate : float
+ Probability of a hidden unit to be zeroed.
+ bias: bool, default ``True``
+ If ``True``, add bias to linear layers.
+ """
+
+ def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, bias: bool = True) -> None:
+ super().__init__()
+ self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff, bias=bias)
+ self.ff_dropout_1 = torch.nn.Dropout(dropout_rate)
+ self.ff_activation = torch.nn.SiLU()
+ self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors, bias=bias)
+ self.ff_linear_3 = nn.Linear(n_factors, n_factors_ff, bias=bias)
+
+ def forward(self, seqs: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass.
+
+ Parameters
+ ----------
+ seqs : torch.Tensor
+ User sequences of item embeddings.
+
+ Returns
+ -------
+ torch.Tensor
+ User sequence that passed through all layers.
+ """
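+        # SwiGLU: elementwise product of the SiLU-activated projection and a separate linear
+        # projection, followed by dropout and the output projection back to `n_factors`.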
+ output = self.ff_activation(self.ff_linear_1(seqs)) * self.ff_linear_3(seqs)
+ fin = self.ff_linear_2(self.ff_dropout_1(output))
+ return fin
+
+
+def init_feed_forward(
+ n_factors: int, ff_factors_multiplier: int, dropout_rate: float, ff_activation: str, bias: bool = True
+) -> nn.Module:
+ """
+    Initialize a feed-forward network with one of the activation functions: "swiglu", "relu" or "gelu".
+
+ Parameters
+ ----------
+ n_factors : int
+ Latent embeddings size.
+    ff_factors_multiplier : int
+        Feed-forward hidden size multiplier: the network uses ``n_factors * ff_factors_multiplier`` hidden units.
+ dropout_rate : float
+ Probability of a hidden unit to be zeroed.
+ ff_activation : {"swiglu", "relu", "gelu"}
+ Activation function to use.
+ bias: bool, default ``True``
+ If ``True``, add bias to linear layers.
+
+ Returns
+ -------
+ nn.Module
+ Feed-Forward network.
+ """
+ if ff_activation == "swiglu":
+ return SwigluFeedForward(n_factors, n_factors * ff_factors_multiplier, dropout_rate, bias=bias)
+ if ff_activation == "gelu":
+ return PointWiseFeedForward(
+ n_factors, n_factors * ff_factors_multiplier, dropout_rate, activation=torch.nn.GELU(), bias=bias
+ )
+ if ff_activation == "relu":
+ return PointWiseFeedForward(
+ n_factors,
+ n_factors * ff_factors_multiplier,
+ dropout_rate,
+ activation=torch.nn.ReLU(),
+ bias=bias,
+ )
+ raise ValueError(f"Unsupported ff_activation: {ff_activation}")
+
+
class TransformerLayersBase(nn.Module):
"""Base class for transformer layers."""
diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py
index 9cb6e658..d1040b89 100644
--- a/tests/models/nn/transformers/test_sasrec.py
+++ b/tests/models/nn/transformers/test_sasrec.py
@@ -33,6 +33,7 @@
TrainerCallable,
TransformerLightningModule,
)
+from rectools.models.nn.transformers.ligr import LiGRLayers
from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule
@@ -760,6 +761,89 @@ def test_torch_model(self, dataset: Dataset) -> None:
model.fit(dataset)
assert isinstance(model.torch_model, TransformerTorchBackbone)
+ @pytest.mark.parametrize(
+        "activation,filter_viewed,expected",
+ (
+ (
+ "swiglu",
+ True,
+ pd.DataFrame(
+ {
+ Columns.User: [10, 10, 30, 30, 30, 40, 40, 40],
+ Columns.Item: [17, 15, 17, 13, 14, 13, 12, 14],
+ Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3],
+ }
+ ),
+ ),
+ (
+ "gelu",
+ True,
+ pd.DataFrame(
+ {
+ Columns.User: [10, 10, 30, 30, 30, 40, 40, 40],
+ Columns.Item: [17, 15, 17, 13, 14, 13, 12, 14],
+ Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3],
+ }
+ ),
+ ),
+ (
+ "relu",
+ True,
+ pd.DataFrame(
+ {
+ Columns.User: [10, 10, 30, 30, 30, 40, 40, 40],
+ Columns.Item: [17, 15, 17, 13, 14, 13, 12, 14],
+ Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3],
+ }
+ ),
+ ),
+ ),
+ )
+ def test_ligr_layers(
+ self,
+ activation: str,
+ dataset: Dataset,
+ filter_viewed: bool,
+ expected: pd.DataFrame,
+ get_trainer_func: TrainerCallable,
+ ) -> None:
+ model = SASRecModel(
+ transformer_layers_type=LiGRLayers,
+ transformer_layers_kwargs={
+ "ff_factors_multiplier": 1,
+ "ff_activation": activation,
+ "bias_in_ff": True,
+ },
+ get_trainer_func=get_trainer_func,
+ n_factors=32,
+ n_blocks=2,
+ session_max_len=3,
+ lr=0.001,
+ batch_size=4,
+ epochs=2,
+ deterministic=True,
+ item_net_block_types=(IdEmbeddingsItemNet,),
+ similarity_module_type=DistanceSimilarityModule,
+ )
+ model.fit(dataset=dataset)
+ users = np.array([10, 30, 40])
+ actual = model.recommend(users=users, dataset=dataset, k=3, filter_viewed=filter_viewed)
+ pd.testing.assert_frame_equal(actual.drop(columns=Columns.Score), expected)
+ pd.testing.assert_frame_equal(
+ actual.sort_values([Columns.User, Columns.Score], ascending=[True, False]).reset_index(drop=True),
+ actual,
+ )
+
+ def test_raises_when_activation_is_not_supported(self, dataset: Dataset) -> None:
+ model = SASRecModel(
+ transformer_layers_type=LiGRLayers,
+ transformer_layers_kwargs={
+ "ff_activation": "not_supported_activation",
+ },
+ )
+ with pytest.raises(ValueError):
+ model.fit(dataset)
+
class TestSASRecDataPreparator: