diff --git a/subwiz/model.py b/subwiz/model.py index ddcade2..5295c9b 100644 --- a/subwiz/model.py +++ b/subwiz/model.py @@ -387,6 +387,7 @@ def _trim_subdomains( self, sequences: torch.Tensor, apex_unpadded_position: int, + num_initial_pad_tokens: int, num_tokens_generated: int, ) -> torch.Tensor: if num_tokens_generated == 0: @@ -394,8 +395,8 @@ def _trim_subdomains( trimming_position = ( apex_unpadded_position - if num_tokens_generated > apex_unpadded_position - else num_tokens_generated + if num_tokens_generated >= num_initial_pad_tokens + else 0 ) sequences = torch.cat( @@ -453,9 +454,9 @@ def generate( apex_padded_position = (sequences == self.delim_token).nonzero(as_tuple=True)[ 1 - ][0] + ][0] + 1 - apex_unpadded_position = 1 + apex_padded_position - num_initial_pad_tokens + apex_unpadded_position = apex_padded_position - num_initial_pad_tokens probabilities = torch.tensor([1.0], device=self.device) @@ -473,6 +474,7 @@ def generate( sequences = self._trim_subdomains( sequences, apex_unpadded_position=apex_unpadded_position, + num_initial_pad_tokens=num_initial_pad_tokens, num_tokens_generated=i, )