From cab85d9a0dc92840c5a8e9d327f7491a5ba9fda5 Mon Sep 17 00:00:00 2001 From: EduardoTerres Date: Mon, 10 Nov 2025 16:37:07 +0100 Subject: [PATCH 1/2] Fix subdomain trimming position for when there is padding --- subwiz/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/subwiz/model.py b/subwiz/model.py index ddcade2..97bd25d 100644 --- a/subwiz/model.py +++ b/subwiz/model.py @@ -395,7 +395,7 @@ def _trim_subdomains( trimming_position = ( apex_unpadded_position if num_tokens_generated > apex_unpadded_position - else num_tokens_generated + else 0 ) sequences = torch.cat( From 049cfdf1687eada133d259de0719f5c115465883 Mon Sep 17 00:00:00 2001 From: EduardoTerres Date: Mon, 10 Nov 2025 16:39:37 +0100 Subject: [PATCH 2/2] Fix subdomain trimming cutoff --- subwiz/model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/subwiz/model.py b/subwiz/model.py index 97bd25d..5295c9b 100644 --- a/subwiz/model.py +++ b/subwiz/model.py @@ -387,6 +387,7 @@ def _trim_subdomains( self, sequences: torch.Tensor, apex_unpadded_position: int, + num_initial_pad_tokens: int, num_tokens_generated: int, ) -> torch.Tensor: if num_tokens_generated == 0: @@ -394,7 +395,7 @@ def _trim_subdomains( trimming_position = ( apex_unpadded_position - if num_tokens_generated > apex_unpadded_position + if num_tokens_generated >= num_initial_pad_tokens else 0 ) @@ -453,9 +454,9 @@ def generate( apex_padded_position = (sequences == self.delim_token).nonzero(as_tuple=True)[ 1 - ][0] + ][0] + 1 - apex_unpadded_position = 1 + apex_padded_position - num_initial_pad_tokens + apex_unpadded_position = apex_padded_position - num_initial_pad_tokens probabilities = torch.tensor([1.0], device=self.device) @@ -473,6 +474,7 @@ def generate( sequences = self._trim_subdomains( sequences, apex_unpadded_position=apex_unpadded_position, + num_initial_pad_tokens=num_initial_pad_tokens, num_tokens_generated=i, )