From 0c653a03b31392cffd953037717735d4dfb5b6c8 Mon Sep 17 00:00:00 2001 From: Temrjan Date: Wed, 25 Mar 2026 09:52:37 +0500 Subject: [PATCH] fix: guard against empty regex matches in WordAug to prevent IndexError __replace() and __text2emoji() crash with IndexError when a token contains only characters not matched by the regex (e.g. parentheses, em-dashes). Return the original token unchanged when re.findall() produces an empty list. Added regression tests reproducing the exact scenario from issue #20. Closes #20 Co-Authored-By: Claude Opus 4.6 (1M context) --- augmentex/word.py | 20 +++++++----- tests/__init__.py | 0 tests/test_word_aug.py | 71 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 8 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_word_aug.py diff --git a/augmentex/word.py b/augmentex/word.py index 4f44b02..10a99fa 100644 --- a/augmentex/word.py +++ b/augmentex/word.py @@ -102,11 +102,13 @@ def __text2emoji(self, word: str) -> str: Returns: str: Emoji that matches this word. """ - word = re.findall("[а-яА-ЯёЁa-zA-Z0-9']+|[.,!?;-]+", word) - words = self.text2emoji_map.get(word[0].lower(), [word[0]]) - word[0] = np.random.choice(words) + tokens = re.findall("[а-яА-ЯёЁa-zA-Z0-9']+|[.,!?;-]+", word) + if not tokens: + return word + words = self.text2emoji_map.get(tokens[0].lower(), [tokens[0]]) + tokens[0] = np.random.choice(words) - return "".join(word) + return "".join(tokens) def __split(self, word: str) -> str: """Divides a word character-by-character. @@ -130,11 +132,13 @@ def __replace(self, word: str) -> str: Returns: str: A misspelled word. """ - word = re.findall("[а-яА-ЯёЁa-zA-Z0-9']+|[.,!?;]+", word) - word_probas = self.orfo_dict.get(word[0].lower(), [[word[0]], [1.0]]) - word[0] = np.random.choice(word_probas[0], p=word_probas[1]) + tokens = re.findall("[а-яА-ЯёЁa-zA-Z0-9']+|[.,!?;]+", word) + if not tokens: + return word + word_probas = self.orfo_dict.get(tokens[0].lower(), [[tokens[0]], [1.0]]) + tokens[0] = np.random.choice(word_probas[0], p=word_probas[1]) - return "".join(word) + return "".join(tokens) def __delete(self) -> str: """Deletes a random word. diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_word_aug.py b/tests/test_word_aug.py new file mode 100644 index 0000000..ed73532 --- /dev/null +++ b/tests/test_word_aug.py @@ -0,0 +1,71 @@ +"""Regression tests for WordAug — covers IndexError on unmatchable tokens.""" + +import unittest + +from augmentex import WordAug + + +class TestWordAugReplace(unittest.TestCase): + """Tests for the 'replace' action.""" + + def setUp(self): + self.aug = WordAug( + unit_prob=1.0, + min_aug=1, + max_aug=5, + lang="rus", + platform="pc", + random_seed=42, + ) + + def test_replace_does_not_crash_on_parentheses(self): + """Regression for #20: tokens like '(' don't match the regex.""" + text = "один из партнёров ( партнёршу)" + # Should not raise IndexError + result = self.aug.augment(text=text, action="replace") + self.assertIsInstance(result, str) + + def test_replace_stress_with_original_issue_text(self): + """Exact reproduction from issue #20 — 100 iterations.""" + text = ( + "это когда в отношениях, один из партнёров насилует и истязает " + "своего партнёра ( партнёршу) бывает абъюзив и по отношению " + "родителей к своим детям" + ) + for _ in range(100): + result = self.aug.augment(text=text, action="replace") + self.assertIsInstance(result, str) + + def test_replace_pure_special_chars(self): + """Tokens consisting entirely of unmatchable characters.""" + text = "hello — world" + result = self.aug.augment(text=text, action="replace") + self.assertIsInstance(result, str) + + +class TestWordAugText2Emoji(unittest.TestCase): + """Tests for the 'text2emoji' action — same bug pattern as replace.""" + + def setUp(self): + self.aug = WordAug( + unit_prob=1.0, + min_aug=1, + max_aug=5, + lang="rus", + platform="pc", + random_seed=42, + ) + + def test_text2emoji_does_not_crash_on_special_chars(self): + text = "привет — мир" + result = self.aug.augment(text=text, action="text2emoji") + self.assertIsInstance(result, str) + + def test_text2emoji_with_parentheses(self): + text = "слово ( другое)" + result = self.aug.augment(text=text, action="text2emoji") + self.assertIsInstance(result, str) + + +if __name__ == "__main__": + unittest.main()