diff --git a/lingbot_map/layers/rope.py b/lingbot_map/layers/rope.py index 7f44e31..a50a562 100644 --- a/lingbot_map/layers/rope.py +++ b/lingbot_map/layers/rope.py @@ -176,8 +176,14 @@ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor feature_dim = tokens.size(-1) // 2 # Get frequency components - max_position = int(positions.max()) + 1 - cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) + max_position = (positions.max() + 1).to(dtype=torch.int64) + +cos_comp, sin_comp = self._compute_frequency_components( + feature_dim, + max_position.item(), + tokens.device, + tokens.dtype, +) # Split features for vertical and horizontal processing vertical_features, horizontal_features = tokens.chunk(2, dim=-1)