From 4bcc11cd9f7ef9e955099f3ff471c0bff99174c2 Mon Sep 17 00:00:00 2001 From: Lakmini Senanayake <35813826+LakminiSenanayake@users.noreply.github.com> Date: Mon, 11 Aug 2025 16:46:06 +0900 Subject: [PATCH] handle edge cases in balanced_truncate https://github.com/thunlp/OpenPrompt/issues/327 --- openprompt/plms/utils.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/openprompt/plms/utils.py b/openprompt/plms/utils.py index cca3206a..d7c12e4e 100644 --- a/openprompt/plms/utils.py +++ b/openprompt/plms/utils.py @@ -88,18 +88,30 @@ def balanced_truncate(input_dict: Dict, num_tokens_to_truncate: int=0) -> Dict: '''truncate the inputs with balance, number of cut tokens is proportional to the part's length. ''' - shortenable_lens = [len(parts) if parts[0]==1 else 0 - for parts in input_dict['shortenable_ids']] + # Handling empty lists in 'shortenable_ids' + shortenable_lens = [len(parts) if parts and parts[0] == 1 else 0 + for parts in input_dict['shortenable_ids']] total_shortenable_len = sum(shortenable_lens) - num_tokens_to_truncate_each_part = [part_len/total_shortenable_len*num_tokens_to_truncate - for part_len in shortenable_lens] + # Handle empty truncation cases + if total_shortenable_len == 0: + num_tokens_to_truncate_each_part = [0] * len(shortenable_lens) + else: + num_tokens_to_truncate_each_part = [ + part_len / total_shortenable_len * num_tokens_to_truncate + for part_len in shortenable_lens + ] + round_list(num_tokens_to_truncate_each_part, num_tokens_to_truncate) truncated_example = defaultdict(list) for key in input_dict: parts = input_dict[key] for num_tokens_to_truncate_part, part in zip(num_tokens_to_truncate_each_part, parts): - truncated_example[key].append(part[:len(part)-num_tokens_to_truncate_part]) + truncate_len = max(len(part) - num_tokens_to_truncate_part, 0) + truncated_example[key].append(part[:truncate_len]) + # Filtering out empty sequences + for key in truncated_example: + truncated_example[key] = [part for part in truncated_example[key] if part] return truncated_example @staticmethod @@ -155,6 +167,7 @@ def padding(input_dict: Dict, max_len: int, pad_id_for_inputs: int=0, pad_id_for_others: int=0) -> None: for key, value in input_dict.items(): if (len(input_dict[key]) > max_len): + continue raise ValueError(f'''Truncated seq length of '{key}' still greater than max length {max_len}."\ "One possible reason is that no enough shortenable parts in template. Try adding {{"shortenable": "True"}} property. ''')