Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions augmentex/word.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,13 @@ def __text2emoji(self, word: str) -> str:
Returns:
str: Emoji that matches this word.
"""
word = re.findall("[а-яА-ЯёЁa-zA-Z0-9']+|[.,!?;-]+", word)
words = self.text2emoji_map.get(word[0].lower(), [word[0]])
word[0] = np.random.choice(words)
tokens = re.findall("[а-яА-ЯёЁa-zA-Z0-9']+|[.,!?;-]+", word)
if not tokens:
return word
words = self.text2emoji_map.get(tokens[0].lower(), [tokens[0]])
tokens[0] = np.random.choice(words)

return "".join(word)
return "".join(tokens)

def __split(self, word: str) -> str:
"""Divides a word character-by-character.
Expand All @@ -130,11 +132,13 @@ def __replace(self, word: str) -> str:
Returns:
str: A misspelled word.
"""
word = re.findall("[а-яА-ЯёЁa-zA-Z0-9']+|[.,!?;]+", word)
word_probas = self.orfo_dict.get(word[0].lower(), [[word[0]], [1.0]])
word[0] = np.random.choice(word_probas[0], p=word_probas[1])
tokens = re.findall("[а-яА-ЯёЁa-zA-Z0-9']+|[.,!?;]+", word)
if not tokens:
return word
word_probas = self.orfo_dict.get(tokens[0].lower(), [[tokens[0]], [1.0]])
tokens[0] = np.random.choice(word_probas[0], p=word_probas[1])

return "".join(word)
return "".join(tokens)

def __delete(self) -> str:
"""Deletes a random word.
Expand Down
Empty file added tests/__init__.py
Empty file.
71 changes: 71 additions & 0 deletions tests/test_word_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Regression tests for WordAug — covers IndexError on unmatchable tokens."""

import unittest

from augmentex import WordAug


class TestWordAugReplace(unittest.TestCase):
    """Regression suite for the ``replace`` action of ``WordAug``."""

    def setUp(self):
        # unittest calls setUp before every test method, so each test gets
        # its own augmenter built with the same fixed random_seed.
        settings = {
            "unit_prob": 1.0,
            "min_aug": 1,
            "max_aug": 5,
            "lang": "rus",
            "platform": "pc",
            "random_seed": 42,
        }
        self.aug = WordAug(**settings)

    def test_replace_does_not_crash_on_parentheses(self):
        """Issue #20: a bare '(' token yields no regex match at all."""
        # The call itself must complete without raising IndexError.
        augmented = self.aug.augment(
            text="один из партнёров ( партнёршу)",
            action="replace",
        )
        self.assertIsInstance(augmented, str)

    def test_replace_stress_with_original_issue_text(self):
        """Run the exact text reported in issue #20 through 100 rounds."""
        sample = (
            "это когда в отношениях, один из партнёров насилует и истязает "
            "своего партнёра ( партнёршу) бывает абъюзив и по отношению "
            "родителей к своим детям"
        )
        rounds = 100
        # The same augmenter instance is reused for every round, so its
        # random state advances between calls.
        for _round in range(rounds):
            self.assertIsInstance(
                self.aug.augment(text=sample, action="replace"),
                str,
            )

    def test_replace_pure_special_chars(self):
        """A token made only of characters outside the word regex ('—')."""
        augmented = self.aug.augment(text="hello — world", action="replace")
        self.assertIsInstance(augmented, str)


class TestWordAugText2Emoji(unittest.TestCase):
    """Regression suite for the ``text2emoji`` action of ``WordAug``.

    Mirrors the ``replace`` tests: the same unmatched-token failure mode
    is exercised through a different action.
    """

    def setUp(self):
        # unittest calls setUp before every test method, so each test gets
        # its own augmenter built with the same fixed random_seed.
        settings = {
            "unit_prob": 1.0,
            "min_aug": 1,
            "max_aug": 5,
            "lang": "rus",
            "platform": "pc",
            "random_seed": 42,
        }
        self.aug = WordAug(**settings)

    def test_text2emoji_does_not_crash_on_special_chars(self):
        """A standalone em dash token must not break the action."""
        augmented = self.aug.augment(text="привет — мир", action="text2emoji")
        self.assertIsInstance(augmented, str)

    def test_text2emoji_with_parentheses(self):
        """A standalone '(' token must not break the action."""
        augmented = self.aug.augment(
            text="слово ( другое)",
            action="text2emoji",
        )
        self.assertIsInstance(augmented, str)


# Allow running this test module directly as a script: unittest.main()
# collects and runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()