Add a CI test workflow, empty-input guards, and a unittest suite.

* .github/workflows/tests.yml: on pushes and pull requests targeting
  main, install the package and run
  `python -m unittest discover -s tests -v` on Python 3.9, 3.10, 3.11
  and 3.12. pip is invoked as `python -m pip` so that the interpreter
  selected by actions/setup-python is the one whose environment is
  modified.
* lekcut/attacut.py (both tokenize() methods) and lekcut/deepcut.py:
  return [] for falsy `text` before any feature extraction or model
  call is made.
* tests/: new package with tests for the deepcut, attacut-sc,
  attacut-c, oskut, sefr-ws1000, sefr-tnhc and sefr-best model names,
  plus the NotImplementedError path for an unknown model name.
  Non-emptiness is checked with assertGreater so a failure reports the
  actual length.

NOTE(review): this patch adds no empty-input guard for the oskut and
sefr-* models; their test_empty_string cases rely on code not shown
here -- confirm they pass.
NOTE(review): the `index` blob ids below were generated for the
original contents of the new files and are informational only.

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 0000000..ed1fe26
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,47 @@
+name: Tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+permissions:
+  contents: read
+
+concurrency:
+  group: >-
+    ${{ github.workflow }}-${{
+    github.event.pull_request.head.repo.full_name || github.repository
+    }}-${{ github.head_ref || github.ref_name }}
+  cancel-in-progress: true
+
+jobs:
+  test:
+    name: Run unit tests
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+
+    steps:
+      - name: Checkout source code
+        uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -r requirements.txt
+
+      - name: Install package
+        run: python -m pip install -e .
+
+      - name: Run tests
+        run: python -m unittest discover -s tests -v
diff --git a/lekcut/attacut.py b/lekcut/attacut.py
index 9d22a81..4937cad 100644
--- a/lekcut/attacut.py
+++ b/lekcut/attacut.py
@@ -203,6 +203,8 @@ def _make_feature(self, txt: str):
         return characters, features
 
     def tokenize(self, text: str) -> List[str]:
+        if not text:
+            return []
         tokens, features = self._make_feature(text)
         logits = self.model.run(None, {"input": features})[0]
         preds = (_sigmoid(logits) > 0.5).astype(int)
@@ -232,6 +234,8 @@ def _make_feature(self, txt: str):
         return characters, features
 
     def tokenize(self, text: str) -> List[str]:
+        if not text:
+            return []
         tokens, features = self._make_feature(text)
         logits = self.model.run(None, {"input": features})[0]
         preds = (_sigmoid(logits) > 0.5).astype(int)
diff --git a/lekcut/deepcut.py b/lekcut/deepcut.py
index 1c09243..8f14195 100644
--- a/lekcut/deepcut.py
+++ b/lekcut/deepcut.py
@@ -138,6 +138,8 @@ def load_model(self, path: str, providers: List[str]=None):
         self.model = ort.InferenceSession(self.path, providers=providers)
 
     def tokenize(self, text: str) -> List[str]:
+        if not text:
+            return []
         self.x_char, self.x_type = create_feature_array(text, n_pad=self.n_pad)
         self.x_char = self.x_char.astype(np.float32)
         self.x_type= self.x_type.astype(np.float32)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_lekcut.py b/tests/test_lekcut.py
new file mode 100644
index 0000000..34d6f97
--- /dev/null
+++ b/tests/test_lekcut.py
@@ -0,0 +1,165 @@
+# -*- coding: utf-8 -*-
+"""Unit tests for the LEKCut Thai word tokenization library."""
+import unittest
+
+from lekcut import word_tokenize
+
+
+class TestWordTokenizeDeepcut(unittest.TestCase):
+    """Tests for the default deepcut model."""
+
+    def test_basic_tokenization(self):
+        result = word_tokenize("ทดสอบการตัดคำ", model="deepcut")
+        self.assertIsInstance(result, list)
+        self.assertGreater(len(result), 0)
+        self.assertEqual("".join(result), "ทดสอบการตัดคำ")
+
+    def test_known_output(self):
+        result = word_tokenize("ทดสอบการตัดคำ", model="deepcut")
+        self.assertEqual(result, ["ทดสอบ", "การ", "ตัด", "คำ"])
+
+    def test_empty_string(self):
+        result = word_tokenize("", model="deepcut")
+        self.assertEqual(result, [])
+
+    def test_single_word(self):
+        result = word_tokenize("สวัสดี", model="deepcut")
+        self.assertIsInstance(result, list)
+        self.assertEqual("".join(result), "สวัสดี")
+
+    def test_with_spaces(self):
+        result = word_tokenize("สวัสดี ครับ", model="deepcut")
+        self.assertIsInstance(result, list)
+        self.assertEqual("".join(result), "สวัสดี ครับ")
+
+    def test_default_model(self):
+        """word_tokenize defaults to deepcut."""
+        result = word_tokenize("ทดสอบการตัดคำ")
+        self.assertEqual(result, ["ทดสอบ", "การ", "ตัด", "คำ"])
+
+
+class TestWordTokenizeAttacutSC(unittest.TestCase):
+    """Tests for the attacut-sc model."""
+
+    def test_basic_tokenization(self):
+        result = word_tokenize("ทดสอบการตัดคำ", model="attacut-sc")
+        self.assertIsInstance(result, list)
+        self.assertGreater(len(result), 0)
+        self.assertEqual("".join(result), "ทดสอบการตัดคำ")
+
+    def test_empty_string(self):
+        result = word_tokenize("", model="attacut-sc")
+        self.assertEqual(result, [])
+
+    def test_output_joins_to_input(self):
+        text = "ภาษาไทยสวยงาม"
+        result = word_tokenize(text, model="attacut-sc")
+        self.assertEqual("".join(result), text)
+
+
+class TestWordTokenizeAttacutC(unittest.TestCase):
+    """Tests for the attacut-c model."""
+
+    def test_basic_tokenization(self):
+        result = word_tokenize("ทดสอบการตัดคำ", model="attacut-c")
+        self.assertIsInstance(result, list)
+        self.assertGreater(len(result), 0)
+        self.assertEqual("".join(result), "ทดสอบการตัดคำ")
+
+    def test_empty_string(self):
+        result = word_tokenize("", model="attacut-c")
+        self.assertEqual(result, [])
+
+    def test_output_joins_to_input(self):
+        text = "ภาษาไทยสวยงาม"
+        result = word_tokenize(text, model="attacut-c")
+        self.assertEqual("".join(result), text)
+
+
+class TestWordTokenizeOskut(unittest.TestCase):
+    """Tests for the oskut model."""
+
+    def test_basic_tokenization(self):
+        result = word_tokenize("ทดสอบการตัดคำ", model="oskut")
+        self.assertIsInstance(result, list)
+        self.assertGreater(len(result), 0)
+        self.assertEqual("".join(result), "ทดสอบการตัดคำ")
+
+    def test_empty_string(self):
+        result = word_tokenize("", model="oskut")
+        self.assertIsInstance(result, list)
+        self.assertEqual(result, [])
+
+    def test_output_joins_to_input(self):
+        text = "ภาษาไทยสวยงาม"
+        result = word_tokenize(text, model="oskut")
+        self.assertEqual("".join(result), text)
+
+
+class TestWordTokenizeSefrWs1000(unittest.TestCase):
+    """Tests for the sefr-ws1000 model."""
+
+    def test_basic_tokenization(self):
+        result = word_tokenize("ทดสอบการตัดคำ", model="sefr-ws1000")
+        self.assertIsInstance(result, list)
+        self.assertGreater(len(result), 0)
+        self.assertEqual("".join(result), "ทดสอบการตัดคำ")
+
+    def test_empty_string(self):
+        result = word_tokenize("", model="sefr-ws1000")
+        self.assertEqual(result, [])
+
+    def test_output_joins_to_input(self):
+        text = "ภาษาไทยสวยงาม"
+        result = word_tokenize(text, model="sefr-ws1000")
+        self.assertEqual("".join(result), text)
+
+
+class TestWordTokenizeSefrTnhc(unittest.TestCase):
+    """Tests for the sefr-tnhc model."""
+
+    def test_basic_tokenization(self):
+        result = word_tokenize("ทดสอบการตัดคำ", model="sefr-tnhc")
+        self.assertIsInstance(result, list)
+        self.assertGreater(len(result), 0)
+        self.assertEqual("".join(result), "ทดสอบการตัดคำ")
+
+    def test_empty_string(self):
+        result = word_tokenize("", model="sefr-tnhc")
+        self.assertEqual(result, [])
+
+    def test_output_joins_to_input(self):
+        text = "ภาษาไทยสวยงาม"
+        result = word_tokenize(text, model="sefr-tnhc")
+        self.assertEqual("".join(result), text)
+
+
+class TestWordTokenizeSefrBest(unittest.TestCase):
+    """Tests for the sefr-best model."""
+
+    def test_basic_tokenization(self):
+        result = word_tokenize("ทดสอบการตัดคำ", model="sefr-best")
+        self.assertIsInstance(result, list)
+        self.assertGreater(len(result), 0)
+        self.assertEqual("".join(result), "ทดสอบการตัดคำ")
+
+    def test_empty_string(self):
+        result = word_tokenize("", model="sefr-best")
+        self.assertEqual(result, [])
+
+    def test_output_joins_to_input(self):
+        text = "ภาษาไทยสวยงาม"
+        result = word_tokenize(text, model="sefr-best")
+        self.assertEqual("".join(result), text)
+
+
+class TestWordTokenizeErrorHandling(unittest.TestCase):
+    """Tests for error handling in word_tokenize."""
+
+    def test_unsupported_model_raises(self):
+        with self.assertRaises(NotImplementedError):
+            word_tokenize("ทดสอบ", model="unknown-model")
+
+
+if __name__ == "__main__":
+    unittest.main()