diff --git a/TODO.md b/TODO.md index 0fb43ab2..95b3fe5f 100644 --- a/TODO.md +++ b/TODO.md @@ -15,6 +15,10 @@ Items marked with :fire: are high priority. - Medium-term typing direction: replace runtime-injected reverse accessors with explicit reverse relationship declarations in model classes so reverse-side usage is mypy-friendly without casts or TYPE_CHECKING hacks. +- Short-term typing upgrade for user code: improve library type hints for + dynamic ORM relationship APIs (reverse FK descriptors, M2M managers, and + prefetched relation result types) to reduce required casts in normal + application code and tests. - Registry lifetime: global registry can cause cross-talk when models are defined repeatedly in one process (e.g., tests). Current mitigation exists via `ModelRegistry.reset()` and snapshot/restore helpers; longer-term option: diff --git a/pyproject.toml b/pyproject.toml index 3a9116ba..24d2d4ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,6 @@ convention = "google" [tool.ruff.lint.extend-per-file-ignores] "tests/**/*.py" = [ "S101", # we can (and MUST!) use 'assert' in test files. - "ANN001", # annotations for fixtures are sometimes a pain for test files "ARG00", # test fixtures often are not directly used "PLR2004", # magic numbers are often used in test files "SLF001", # sometimes we need to test private methods @@ -155,10 +154,6 @@ plugins = ["pydantic.mypy"] python_version = "3.9" exclude = ["docs"] -[[tool.mypy.overrides]] -disable_error_code = ["method-assign", "no-untyped-def", "attr-defined"] -module = "tests.*" - [tool.pytest.ini_options] asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function" diff --git a/tests/conftest.py b/tests/conftest.py index 8c601283..2da955a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,19 +15,20 @@ if TYPE_CHECKING: from collections.abc import Generator + from pathlib import Path memory_db = ":memory:" @pytest.hookimpl(tryfirst=True) -def pytest_configure(config) -> None: +def pytest_configure(config: pytest.Config) -> None: """Clear the screen before running tests.""" os.system("cls" if os.name == "nt" else "clear") # noqa: S605 @contextmanager -def not_raises(exception) -> Generator[None, Any, None]: +def not_raises(exception: type[BaseException]) -> Generator[None, Any, None]: """Fake a pytest.raises context manager that does not raise an exception. Use: `with not_raises(Exception):` @@ -198,7 +199,7 @@ def db_mock_complex_debug() -> SqliterDB: @pytest.fixture -def temp_db_path(tmp_path) -> str: +def temp_db_path(tmp_path: Path) -> str: """Fixture to create a temporary database file path.""" return str(tmp_path / "test_db.sqlite") diff --git a/tests/test_advanced_filters.py b/tests/test_advanced_filters.py index e9562715..01cb3326 100644 --- a/tests/test_advanced_filters.py +++ b/tests/test_advanced_filters.py @@ -2,31 +2,37 @@ import pytest +from sqliter.sqliter import SqliterDB + from .conftest import PersonModel class TestAdvancedFilters: """Test class containing methods to test advanced filter capabilities.""" - def test_filter_with_gt_condition(self, db_mock_adv) -> None: + def test_filter_with_gt_condition(self, db_mock_adv: SqliterDB) -> None: """Test filter with greater than condition.""" # Filter where age > 25 results = db_mock_adv.select(PersonModel).filter(age__gt=25).fetch_all() assert len(results) == 2 - assert all(result.age > 25 for result in results) + assert all( + (age := result.age) is not None and age > 25 for result in results + ) assert {result.name for result in results} == {"Bob", "Charlie"} - def test_filter_with_lt_condition(self, db_mock_adv) -> None: + def test_filter_with_lt_condition(self, db_mock_adv: SqliterDB) -> None: """Test filter with less than condition.""" # Filter where age < 35 results = db_mock_adv.select(PersonModel).filter(age__lt=35).fetch_all() assert len(results) == 2 - assert all(result.age < 35 for result in results) + assert all( + (age := result.age) is not None and age < 35 for result in results + ) assert {result.name for result in results} == {"Alice", "Bob"} - def test_filter_with_gte_condition(self, db_mock_adv) -> None: + def test_filter_with_gte_condition(self, db_mock_adv: SqliterDB) -> None: """Test filter with greater than or equal to condition.""" # Filter where age >= 30 results = ( @@ -34,10 +40,12 @@ def test_filter_with_gte_condition(self, db_mock_adv) -> None: ) assert len(results) == 2 - assert all(result.age >= 30 for result in results) + assert all( + (age := result.age) is not None and age >= 30 for result in results + ) assert {result.name for result in results} == {"Bob", "Charlie"} - def test_filter_with_lte_condition(self, db_mock_adv) -> None: + def test_filter_with_lte_condition(self, db_mock_adv: SqliterDB) -> None: """Test filter with less than or equal to condition.""" # Filter where age <= 30 results = ( @@ -45,10 +53,12 @@ def test_filter_with_lte_condition(self, db_mock_adv) -> None: ) assert len(results) == 2 - assert all(result.age <= 30 for result in results) + assert all( + (age := result.age) is not None and age <= 30 for result in results + ) assert {result.name for result in results} == {"Alice", "Bob"} - def test_filter_with_eq_condition(self, db_mock_adv) -> None: + def test_filter_with_eq_condition(self, db_mock_adv: SqliterDB) -> None: """Test filter with equal to condition.""" # Filter where age == 30 results = db_mock_adv.select(PersonModel).filter(age__eq=30).fetch_all() @@ -57,7 +67,7 @@ def test_filter_with_eq_condition(self, db_mock_adv) -> None: assert results[0].age == 30 assert results[0].name == "Bob" - def test_filter_with_ne_condition(self, db_mock_adv) -> None: + def test_filter_with_ne_condition(self, db_mock_adv: SqliterDB) -> None: """Test filter with not equal to condition.""" # Filter where age != 30 results = db_mock_adv.select(PersonModel).filter(age__ne=30).fetch_all() @@ -66,7 +76,9 @@ def test_filter_with_ne_condition(self, db_mock_adv) -> None: assert all(result.age != 30 for result in results) assert {result.name for result in results} == {"Alice", "Charlie"} - def test_filter_with_gt_and_lt_combined(self, db_mock_adv) -> None: + def test_filter_with_gt_and_lt_combined( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with combined greater than and less than conditions.""" db_mock_adv.insert(PersonModel(name="David", age=40)) @@ -78,10 +90,15 @@ def test_filter_with_gt_and_lt_combined(self, db_mock_adv) -> None: ) assert len(results) == 2 - assert all(25 < result.age < 40 for result in results) + assert all( + (age := result.age) is not None and 25 < age < 40 + for result in results + ) assert {result.name for result in results} == {"Bob", "Charlie"} - def test_filter_with_gt_and_lte_combined(self, db_mock_adv) -> None: + def test_filter_with_gt_and_lte_combined( + self, db_mock_adv: SqliterDB + ) -> None: """Test with combined greater than and less than or equal conditions.""" db_mock_adv.insert(PersonModel(name="David", age=40)) @@ -93,10 +110,15 @@ def test_filter_with_gt_and_lte_combined(self, db_mock_adv) -> None: ) assert len(results) == 2 - assert all(25 < result.age <= 35 for result in results) + assert all( + (age := result.age) is not None and 25 < age <= 35 + for result in results + ) assert {result.name for result in results} == {"Bob", "Charlie"} - def test_filter_with_is_null_condition(self, db_mock_adv) -> None: + def test_filter_with_is_null_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with IS NULL condition.""" db_mock_adv.insert(PersonModel(name="David", age=None)) # Filter where age is NULL @@ -108,7 +130,9 @@ def test_filter_with_is_null_condition(self, db_mock_adv) -> None: assert results[0].age is None assert results[0].name == "David" - def test_filter_with_is_not_null_condition(self, db_mock_adv) -> None: + def test_filter_with_is_not_null_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with IS NOT NULL condition.""" db_mock_adv.insert(PersonModel(name="David", age=None)) # Filter where age is NOT NULL @@ -126,7 +150,7 @@ def test_filter_with_is_not_null_condition(self, db_mock_adv) -> None: "Charlie", } - def test_filter_with_in_condition(self, db_mock_adv) -> None: + def test_filter_with_in_condition(self, db_mock_adv: SqliterDB) -> None: """Test filter with IN condition.""" # Filter where age IN (25, 35) results = ( @@ -137,7 +161,7 @@ def test_filter_with_in_condition(self, db_mock_adv) -> None: assert all(result.age in [25, 35] for result in results) assert {result.name for result in results} == {"Alice", "Charlie"} - def test_filter_with_not_in_condition(self, db_mock_adv) -> None: + def test_filter_with_not_in_condition(self, db_mock_adv: SqliterDB) -> None: """Test filter with NOT IN condition.""" # Filter where age NOT IN (25, 35) results = ( @@ -150,21 +174,25 @@ def test_filter_with_not_in_condition(self, db_mock_adv) -> None: assert results[0].age == 30 assert results[0].name == "Bob" - def test_filter_with_bad_in_condition(self, db_mock_adv) -> None: + def test_filter_with_bad_in_condition(self, db_mock_adv: SqliterDB) -> None: """Test filter with bad IN condition.""" with pytest.raises(TypeError, match="age requires a list") as exc_info: db_mock_adv.select(PersonModel).filter(age__in=25).fetch_all() assert str(exc_info.value) == "age requires a list for '__in'" - def test_filter_with_bad_not_in_condition(self, db_mock_adv) -> None: + def test_filter_with_bad_not_in_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with bad NOT IN condition.""" with pytest.raises(TypeError, match="age requires a list") as exc_info: db_mock_adv.select(PersonModel).filter(age__not_in=25).fetch_all() assert str(exc_info.value) == "age requires a list for '__not_in'" - def test_filter_with_starts_with_condition(self, db_mock_adv) -> None: + def test_filter_with_starts_with_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with starts with condition (default case-sensitive).""" db_mock_adv.insert(PersonModel(name="alison", age=50)) # Filter where name starts with 'A' case-sensitive @@ -179,7 +207,7 @@ def test_filter_with_starts_with_condition(self, db_mock_adv) -> None: def test_filter_with_starts_with_condition_case_insensitive( self, - db_mock_adv, + db_mock_adv: SqliterDB, ) -> None: """Test filter with starts with condition (case insensitive).""" db_mock_adv.insert(PersonModel(name="alison", age=50)) @@ -193,7 +221,9 @@ def test_filter_with_starts_with_condition_case_insensitive( assert len(results) == 2 assert results[0].name == "Alice" - def test_filter_with_bad_starts_with_condition(self, db_mock_adv) -> None: + def test_filter_with_bad_starts_with_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with bad starts with condition.""" with pytest.raises( TypeError, match="name requires a string" @@ -207,7 +237,9 @@ def test_filter_with_bad_starts_with_condition(self, db_mock_adv) -> None: == "name requires a string value for '__startswith'" ) - def test_filter_with_ends_with_condition(self, db_mock_adv) -> None: + def test_filter_with_ends_with_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with ends with condition (case sensitive).""" db_mock_adv.insert(PersonModel(name="DALE", age=2)) @@ -219,11 +251,14 @@ def test_filter_with_ends_with_condition(self, db_mock_adv) -> None: ) assert len(results) == 2 - assert all(result.name.endswith("e") for result in results) + assert all( + (name := result.name) is not None and name.endswith("e") + for result in results + ) def test_filter_with_ends_with_condition_case_insensitive( self, - db_mock_adv, + db_mock_adv: SqliterDB, ) -> None: """Test filter with ends with condition (case insensitive).""" # Filter where name ends with 'e' (case insensitive) @@ -234,9 +269,14 @@ def test_filter_with_ends_with_condition_case_insensitive( ) assert len(results) == 2 - assert all(result.name.endswith("e") for result in results) + assert all( + (name := result.name) is not None and name.endswith("e") + for result in results + ) - def test_filter_with_bad_ends_with_condition(self, db_mock_adv) -> None: + def test_filter_with_bad_ends_with_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with bad ends with condition.""" with pytest.raises( TypeError, match="name requires a string" @@ -250,7 +290,9 @@ def test_filter_with_bad_ends_with_condition(self, db_mock_adv) -> None: == "name requires a string value for '__endswith'" ) - def test_filter_with_contains_condition(self, db_mock_adv) -> None: + def test_filter_with_contains_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with contains condition (case-sensitive).""" # Add one more record for our test db_mock_adv.insert(PersonModel(name="Lianne", age=40)) @@ -290,7 +332,9 @@ def test_filter_with_contains_condition(self, db_mock_adv) -> None: ) assert len(results) == 0 - def test_filter_with_icontains_condition(self, db_mock_adv) -> None: + def test_filter_with_icontains_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with case-insensitive contains condition.""" # No need to insert new records, we'll use existing ones @@ -321,7 +365,9 @@ def test_filter_with_icontains_condition(self, db_mock_adv) -> None: assert len(results) == 2 assert {r.name for r in results} == {"Alice", "Charlie"} - def test_filter_with_bad_contains_condition(self, db_mock_adv) -> None: + def test_filter_with_bad_contains_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with bad contains condition.""" with pytest.raises( TypeError, match="name requires a string" @@ -335,7 +381,7 @@ def test_filter_with_bad_contains_condition(self, db_mock_adv) -> None: == "name requires a string value for '__contains'" ) - def test_multiple_chained_filters(self, db_mock_adv) -> None: + def test_multiple_chained_filters(self, db_mock_adv: SqliterDB) -> None: """Test multiple chained filters.""" # Insert an additional record db_mock_adv.insert(PersonModel(name="Alex", age=28)) @@ -352,7 +398,7 @@ def test_multiple_chained_filters(self, db_mock_adv) -> None: assert results[0].age == 28 def test_all_records_with_multiple_inclusive_filters( - self, db_mock_adv + self, db_mock_adv: SqliterDB ) -> None: """Test using multiple filters in same filter() call.""" results = ( @@ -368,7 +414,9 @@ def test_all_records_with_multiple_inclusive_filters( "Charlie", } - def test_name_isnull_and_notnull_filters(self, db_mock_adv) -> None: + def test_name_isnull_and_notnull_filters( + self, db_mock_adv: SqliterDB + ) -> None: """Test various filters with __isnull and __notnull.""" # Test __isnull=False results = ( @@ -424,7 +472,7 @@ def test_name_isnull_and_notnull_filters(self, db_mock_adv) -> None: assert len(results) == 1 assert results[0].name is None - def test_filter_with_like_condition(self, db_mock_adv) -> None: + def test_filter_with_like_condition(self, db_mock_adv: SqliterDB) -> None: """Test filter with LIKE condition using wildcards.""" # Filter where name matches 'A%' (starts with A) results = ( @@ -433,7 +481,7 @@ def test_filter_with_like_condition(self, db_mock_adv) -> None: assert len(results) == 1 assert results[0].name == "Alice" - def test_filter_with_like_ends_with(self, db_mock_adv) -> None: + def test_filter_with_like_ends_with(self, db_mock_adv: SqliterDB) -> None: """Test filter with LIKE condition for ends with pattern.""" # Filter where name matches '%e' (ends with e) results = ( @@ -442,7 +490,7 @@ def test_filter_with_like_ends_with(self, db_mock_adv) -> None: assert len(results) == 2 assert {r.name for r in results} == {"Alice", "Charlie"} - def test_filter_with_like_contains(self, db_mock_adv) -> None: + def test_filter_with_like_contains(self, db_mock_adv: SqliterDB) -> None: """Test filter with LIKE condition for contains pattern.""" # Filter where name matches '%li%' (contains 'li') results = ( @@ -453,7 +501,9 @@ def test_filter_with_like_contains(self, db_mock_adv) -> None: assert len(results) == 2 assert {r.name for r in results} == {"Alice", "Charlie"} - def test_filter_with_like_single_char_wildcard(self, db_mock_adv) -> None: + def test_filter_with_like_single_char_wildcard( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with LIKE condition using single character wildcard.""" # Filter where name matches '_ob' (3 chars ending in 'ob') results = ( @@ -462,7 +512,9 @@ def test_filter_with_like_single_char_wildcard(self, db_mock_adv) -> None: assert len(results) == 1 assert results[0].name == "Bob" - def test_filter_with_like_case_insensitive(self, db_mock_adv) -> None: + def test_filter_with_like_case_insensitive( + self, db_mock_adv: SqliterDB + ) -> None: """Test that LIKE is case-insensitive in SQLite by default.""" # SQLite LIKE is case-insensitive for ASCII characters results = ( @@ -471,14 +523,16 @@ def test_filter_with_like_case_insensitive(self, db_mock_adv) -> None: assert len(results) == 1 assert results[0].name == "Alice" - def test_filter_with_like_no_match(self, db_mock_adv) -> None: + def test_filter_with_like_no_match(self, db_mock_adv: SqliterDB) -> None: """Test filter with LIKE condition that matches nothing.""" results = ( db_mock_adv.select(PersonModel).filter(name__like="Z%").fetch_all() ) assert len(results) == 0 - def test_filter_with_bad_like_condition(self, db_mock_adv) -> None: + def test_filter_with_bad_like_condition( + self, db_mock_adv: SqliterDB + ) -> None: """Test filter with bad LIKE condition (non-string value).""" with pytest.raises( TypeError, match="name requires a string" diff --git a/tests/test_aggregates.py b/tests/test_aggregates.py index e7fa12a9..c2898922 100644 --- a/tests/test_aggregates.py +++ b/tests/test_aggregates.py @@ -714,7 +714,7 @@ def test_projection_query_sqlite_error_raises_record_fetch_error( .annotate(total=func.sum("amount")) ) - def raise_sqlite_error(*_args, **_kwargs) -> None: # noqa: ANN002, ANN003 + def raise_sqlite_error(*_args: object, **_kwargs: object) -> None: err = "broken" raise sqlite3.Error(err) diff --git a/tests/test_bulk_update.py b/tests/test_bulk_update.py index 7fb3433e..13695c1c 100644 --- a/tests/test_bulk_update.py +++ b/tests/test_bulk_update.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: from pathlib import Path + from pytest_mock import MockerFixture + import pytest from sqliter import SqliterDB @@ -209,7 +211,7 @@ def test_update_multiple_invalid_fields(self, db: SqliterDB) -> None: assert "bar" in error_msg def test_update_sqlite_error_rolls_back_and_raises( - self, db: SqliterDB, mocker + self, db: SqliterDB, mocker: MockerFixture ) -> None: """SQLite update errors rollback and raise RecordUpdateError.""" db.insert(SimpleModel(name="test", value=10)) @@ -459,7 +461,7 @@ def test_update_pk_raises_error(self, db: SqliterDB) -> None: assert "pk" in str(exc_info.value) def test_update_auto_sets_updated_at( - self, db: SqliterDB, monkeypatch + self, db: SqliterDB, monkeypatch: pytest.MonkeyPatch ) -> None: """Bulk update auto-sets updated_at timestamp.""" fake_time = 1000000000.0 @@ -504,7 +506,7 @@ class TestUpdateWhereTimestamps: """Test updated_at behavior in update_where().""" def test_update_where_auto_sets_updated_at( - self, db: SqliterDB, monkeypatch + self, db: SqliterDB, monkeypatch: pytest.MonkeyPatch ) -> None: """update_where auto-sets updated_at timestamp.""" fake_time = 1000000000.0 diff --git a/tests/test_cache.py b/tests/test_cache.py index 70de6a3f..4cb75e70 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import OrderedDict -from typing import Any +from typing import TYPE_CHECKING, Any from unittest.mock import patch import pytest @@ -11,6 +11,9 @@ from sqliter import SqliterDB from sqliter.model import BaseDBModel +if TYPE_CHECKING: + from pathlib import Path + class User(BaseDBModel): """Test model for caching tests.""" @@ -22,9 +25,9 @@ class User(BaseDBModel): class TestCacheDisabledByDefault: """Test that caching is disabled by default.""" - def test_cache_disabled_by_default(self, tmp_path) -> None: + def test_cache_disabled_by_default(self, tmp_path: Path) -> None: """Verify caching is off unless explicitly enabled.""" - db = SqliterDB(tmp_path / "test.db") + db = SqliterDB(str(tmp_path / "test.db")) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -40,50 +43,56 @@ def test_cache_disabled_by_default(self, tmp_path) -> None: class TestCacheParameterValidation: """Test validation of cache configuration parameters.""" - def test_cache_max_size_must_be_positive(self, tmp_path) -> None: + def test_cache_max_size_must_be_positive(self, tmp_path: Path) -> None: """cache_max_size must be greater than 0.""" with pytest.raises( ValueError, match="cache_max_size must be greater than 0" ): SqliterDB( - tmp_path / "test.db", cache_enabled=True, cache_max_size=0 + str(tmp_path / "test.db"), cache_enabled=True, cache_max_size=0 ) with pytest.raises( ValueError, match="cache_max_size must be greater than 0" ): SqliterDB( - tmp_path / "test.db", cache_enabled=True, cache_max_size=-1 + str(tmp_path / "test.db"), cache_enabled=True, cache_max_size=-1 ) - def test_cache_ttl_must_be_non_negative(self, tmp_path) -> None: + def test_cache_ttl_must_be_non_negative(self, tmp_path: Path) -> None: """cache_ttl must be non-negative.""" with pytest.raises(ValueError, match="cache_ttl must be non-negative"): - SqliterDB(tmp_path / "test.db", cache_enabled=True, cache_ttl=-1) + SqliterDB( + str(tmp_path / "test.db"), cache_enabled=True, cache_ttl=-1 + ) - def test_cache_max_memory_mb_must_be_positive(self, tmp_path) -> None: + def test_cache_max_memory_mb_must_be_positive(self, tmp_path: Path) -> None: """cache_max_memory_mb must be greater than 0.""" with pytest.raises( ValueError, match="cache_max_memory_mb must be greater than 0" ): SqliterDB( - tmp_path / "test.db", cache_enabled=True, cache_max_memory_mb=0 + str(tmp_path / "test.db"), + cache_enabled=True, + cache_max_memory_mb=0, ) with pytest.raises( ValueError, match="cache_max_memory_mb must be greater than 0" ): SqliterDB( - tmp_path / "test.db", cache_enabled=True, cache_max_memory_mb=-1 + str(tmp_path / "test.db"), + cache_enabled=True, + cache_max_memory_mb=-1, ) class TestCacheHitOnRepeatedQuery: """Test cache hits on repeated queries.""" - def test_cache_hit_on_repeated_query(self, tmp_path) -> None: + def test_cache_hit_on_repeated_query(self, tmp_path: Path) -> None: """Repeated queries return cached result and increment hit counter.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -111,9 +120,9 @@ def test_cache_hit_on_repeated_query(self, tmp_path) -> None: class TestGetCacheControls: """Test caching behavior for get() calls.""" - def test_get_cache_hits(self, tmp_path) -> None: + def test_get_cache_hits(self, tmp_path: Path) -> None: """get() uses cache on repeated lookups.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) user = db.insert(User(name="Alice", age=30)) @@ -131,9 +140,9 @@ def test_get_cache_hits(self, tmp_path) -> None: db.close() - def test_get_bypass_cache(self, tmp_path) -> None: + def test_get_bypass_cache(self, tmp_path: Path) -> None: """get(bypass_cache=True) skips cache read/write.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) user = db.insert(User(name="Alice", age=30)) @@ -144,9 +153,11 @@ def test_get_bypass_cache(self, tmp_path) -> None: db.close() - def test_get_cache_ttl_override(self, tmp_path) -> None: + def test_get_cache_ttl_override(self, tmp_path: Path) -> None: """get(cache_ttl=...) overrides global TTL.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True, cache_ttl=100) + db = SqliterDB( + str(tmp_path / "test.db"), cache_enabled=True, cache_ttl=100 + ) db.create_table(User) user = db.insert(User(name="Alice", age=30)) @@ -162,9 +173,9 @@ def test_get_cache_ttl_override(self, tmp_path) -> None: db.close() - def test_get_cache_ttl_negative(self, tmp_path) -> None: + def test_get_cache_ttl_negative(self, tmp_path: Path) -> None: """get(cache_ttl=...) rejects negative values.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) with pytest.raises(ValueError, match="cache_ttl must be non-negative"): @@ -176,9 +187,9 @@ def test_get_cache_ttl_negative(self, tmp_path) -> None: class TestCacheInvalidation: """Test cache invalidation on write operations.""" - def test_cache_invalidation_on_insert(self, tmp_path) -> None: + def test_cache_invalidation_on_insert(self, tmp_path: Path) -> None: """Insert clears table cache.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -195,9 +206,9 @@ def test_cache_invalidation_on_insert(self, tmp_path) -> None: db.close() - def test_cache_invalidation_on_update(self, tmp_path) -> None: + def test_cache_invalidation_on_update(self, tmp_path: Path) -> None: """Update clears table cache.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) user = db.insert(User(name="Alice", age=30)) @@ -215,9 +226,9 @@ def test_cache_invalidation_on_update(self, tmp_path) -> None: db.close() - def test_cache_invalidation_on_delete_by_pk(self, tmp_path) -> None: + def test_cache_invalidation_on_delete_by_pk(self, tmp_path: Path) -> None: """Delete by primary key clears table cache.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) user = db.insert(User(name="Alice", age=30)) @@ -234,9 +245,11 @@ def test_cache_invalidation_on_delete_by_pk(self, tmp_path) -> None: db.close() - def test_cache_invalidation_on_delete_by_query(self, tmp_path) -> None: + def test_cache_invalidation_on_delete_by_query( + self, tmp_path: Path + ) -> None: """Delete by query clears table cache.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) db.insert(User(name="Bob", age=25)) @@ -259,9 +272,9 @@ def test_cache_invalidation_on_delete_by_query(self, tmp_path) -> None: class TestCacheClearedOnClose: """Test that cache is cleared when connection is closed.""" - def test_cache_cleared_on_close(self, tmp_path) -> None: + def test_cache_cleared_on_close(self, tmp_path: Path) -> None: """Cache cleared when connection closed.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -275,9 +288,9 @@ def test_cache_cleared_on_close(self, tmp_path) -> None: # Cache should be cleared assert len(db._cache) == 0 - def test_cache_context_manager(self, tmp_path) -> None: + def test_cache_context_manager(self, tmp_path: Path) -> None: """Cache cleared when using context manager.""" - with SqliterDB(tmp_path / "test.db", cache_enabled=True) as db: + with SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) as db: db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -292,9 +305,11 @@ def test_cache_context_manager(self, tmp_path) -> None: class TestCacheTtlExpiration: """Test TTL-based cache expiration.""" - def test_cache_ttl_expiration(self, tmp_path) -> None: + def test_cache_ttl_expiration(self, tmp_path: Path) -> None: """Entries expire after TTL.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True, cache_ttl=1) + db = SqliterDB( + str(tmp_path / "test.db"), cache_enabled=True, cache_ttl=1 + ) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -315,10 +330,10 @@ def test_cache_ttl_expiration(self, tmp_path) -> None: class TestCacheMaxSizeLru: """Test LRU eviction when cache is full.""" - def test_cache_max_size_lru(self, tmp_path) -> None: + def test_cache_max_size_lru(self, tmp_path: Path) -> None: """Oldest entries evicted when max size reached.""" db = SqliterDB( - tmp_path / "test.db", cache_enabled=True, cache_max_size=2 + str(tmp_path / "test.db"), cache_enabled=True, cache_max_size=2 ) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -476,9 +491,9 @@ def test_clear_cache_allows_fresh_queries(self) -> None: class TestCacheKeyVariations: """Test that different query parameters create different cache keys.""" - def test_cache_key_includes_filters(self, tmp_path) -> None: + def test_cache_key_includes_filters(self, tmp_path: Path) -> None: """Different filters create different cache entries.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) db.insert(User(name="Bob", age=25)) @@ -497,9 +512,9 @@ def test_cache_key_includes_filters(self, tmp_path) -> None: db.close() - def test_cache_key_includes_limit_offset(self, tmp_path) -> None: + def test_cache_key_includes_limit_offset(self, tmp_path: Path) -> None: """Different pagination creates different cache entries.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) db.insert(User(name="Bob", age=25)) @@ -517,9 +532,9 @@ def test_cache_key_includes_limit_offset(self, tmp_path) -> None: db.close() - def test_cache_key_includes_order_by(self, tmp_path) -> None: + def test_cache_key_includes_order_by(self, tmp_path: Path) -> None: """Different order by creates different cache entries.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) db.insert(User(name="Bob", age=25)) @@ -543,9 +558,9 @@ def test_cache_key_includes_order_by(self, tmp_path) -> None: class TestCacheEmptyResults: """Test caching of empty results.""" - def test_cache_empty_single_result(self, tmp_path) -> None: + def test_cache_empty_single_result(self, tmp_path: Path) -> None: """Empty single results are cached and retrieved from cache.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -572,9 +587,9 @@ def test_cache_empty_single_result(self, tmp_path) -> None: db.close() - def test_cache_empty_list_result(self, tmp_path) -> None: + def test_cache_empty_list_result(self, tmp_path: Path) -> None: """Empty list results are cached and retrieved from cache.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -605,9 +620,9 @@ def test_cache_empty_list_result(self, tmp_path) -> None: class TestCacheWithFields: """Test caching with field selection.""" - def test_cache_with_field_selection(self, tmp_path) -> None: + def test_cache_with_field_selection(self, tmp_path: Path) -> None: """Different field selections create different cache entries.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -630,9 +645,9 @@ def test_cache_with_field_selection(self, tmp_path) -> None: class TestCacheStatistics: """Test cache statistics tracking.""" - def test_cache_stats_initial_state(self, tmp_path) -> None: + def test_cache_stats_initial_state(self, tmp_path: Path) -> None: """Cache stats start at zero.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) stats = db.get_cache_stats() @@ -643,9 +658,9 @@ def test_cache_stats_initial_state(self, tmp_path) -> None: db.close() - def test_cache_stats_track_hits(self, tmp_path) -> None: + def test_cache_stats_track_hits(self, tmp_path: Path) -> None: """Cache stats track hits correctly.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -664,9 +679,9 @@ def test_cache_stats_track_hits(self, tmp_path) -> None: db.close() - def test_cache_stats_track_misses(self, tmp_path) -> None: + def test_cache_stats_track_misses(self, tmp_path: Path) -> None: """Cache stats track misses correctly.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -682,9 +697,9 @@ def test_cache_stats_track_misses(self, tmp_path) -> None: db.close() - def test_cache_stats_with_invalidation(self, tmp_path) -> None: + def test_cache_stats_with_invalidation(self, tmp_path: Path) -> None: """Cache stats continue after invalidation.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -708,9 +723,9 @@ def test_cache_stats_with_invalidation(self, tmp_path) -> None: db.close() - def test_cache_stats_disabled_cache(self, tmp_path) -> None: + def test_cache_stats_disabled_cache(self, tmp_path: Path) -> None: """Cache stats don't increment when cache is disabled.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=False) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=False) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -725,9 +740,11 @@ def test_cache_stats_disabled_cache(self, tmp_path) -> None: db.close() - def test_cache_stats_with_ttl_expiration(self, tmp_path) -> None: + def test_cache_stats_with_ttl_expiration(self, tmp_path: Path) -> None: """Expired cache entries count as misses.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True, cache_ttl=1) + db = SqliterDB( + str(tmp_path / "test.db"), cache_enabled=True, cache_ttl=1 + ) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -753,9 +770,9 @@ def test_cache_stats_with_ttl_expiration(self, tmp_path) -> None: db.close() - def test_cache_stats_hit_rate_calculation(self, tmp_path) -> None: + def test_cache_stats_hit_rate_calculation(self, tmp_path: Path) -> None: """Hit rate is calculated correctly.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -775,14 +792,14 @@ def test_cache_stats_hit_rate_calculation(self, tmp_path) -> None: class TestCacheMemoryLimit: """Test memory-based cache limiting.""" - def test_memory_usage_with_set_fields(self, tmp_path) -> None: + def test_memory_usage_with_set_fields(self, tmp_path: Path) -> None: """Memory usage calculation works with set fields.""" class ModelWithSet(BaseDBModel): name: str tags: set[str] - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(ModelWithSet) db.insert( ModelWithSet(name="test", tags={"python", "database", "caching"}) @@ -797,7 +814,7 @@ class ModelWithSet(BaseDBModel): db.close() - def test_memory_limit_enforcement(self, tmp_path) -> None: + def test_memory_limit_enforcement(self, tmp_path: Path) -> None: """Cache enforces memory limit by evicting entries.""" # Create a model with large fields to consume memory quickly @@ -807,7 +824,7 @@ class LargeData(BaseDBModel): # Set a very low memory limit (1MB) db = SqliterDB( - tmp_path / "test.db", + str(tmp_path / "test.db"), cache_enabled=True, cache_max_memory_mb=1, ) @@ -836,10 +853,10 @@ class LargeData(BaseDBModel): db.close() - def test_memory_usage_tracking(self, tmp_path) -> None: + def test_memory_usage_tracking(self, tmp_path: Path) -> None: """Memory usage is calculated on-demand per table.""" db = SqliterDB( - tmp_path / "test.db", cache_enabled=True, cache_max_memory_mb=1 + str(tmp_path / "test.db"), cache_enabled=True, cache_max_memory_mb=1 ) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -855,11 +872,11 @@ def test_memory_usage_tracking(self, tmp_path) -> None: def test_memory_tracking_cleared_on_invalidation( self, - tmp_path, + tmp_path: Path, ) -> None: """Memory usage is 0 when cache is invalidated.""" db = SqliterDB( - tmp_path / "test.db", cache_enabled=True, cache_max_memory_mb=1 + str(tmp_path / "test.db"), cache_enabled=True, cache_max_memory_mb=1 ) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -877,10 +894,10 @@ def test_memory_tracking_cleared_on_invalidation( db.close() - def test_memory_tracking_cleared_on_close(self, tmp_path) -> None: + def test_memory_tracking_cleared_on_close(self, tmp_path: Path) -> None: """Memory usage is 0 when connection is closed.""" db = SqliterDB( - tmp_path / "test.db", cache_enabled=True, cache_max_memory_mb=1 + str(tmp_path / "test.db"), cache_enabled=True, cache_max_memory_mb=1 ) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -895,10 +912,12 @@ def test_memory_tracking_cleared_on_close(self, tmp_path) -> None: # Memory usage should be 0 (cache was cleared) assert db._get_table_memory_usage(User.get_table_name()) == 0 - def test_memory_tracking_cleared_on_context_exit(self, tmp_path) -> None: + def test_memory_tracking_cleared_on_context_exit( + self, tmp_path: Path + ) -> None: """Memory usage is 0 when exiting context manager.""" with SqliterDB( - tmp_path / "test.db", cache_enabled=True, cache_max_memory_mb=1 + str(tmp_path / "test.db"), cache_enabled=True, cache_max_memory_mb=1 ) as db: db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -910,11 +929,11 @@ def test_memory_tracking_cleared_on_context_exit(self, tmp_path) -> None: # After exiting context, memory usage should be 0 assert db._get_table_memory_usage(User.get_table_name()) == 0 - def test_memory_limit_with_both_limits(self, tmp_path) -> None: + def test_memory_limit_with_both_limits(self, tmp_path: Path) -> None: """Both cache_max_size and cache_max_memory_mb are respected.""" # Set both limits: 10 entries OR 1MB (whichever hit first) db = SqliterDB( - tmp_path / "test.db", + str(tmp_path / "test.db"), cache_enabled=True, cache_max_size=10, cache_max_memory_mb=1, @@ -935,10 +954,10 @@ def test_memory_limit_with_both_limits(self, tmp_path) -> None: db.close() - def test_no_memory_limit_when_none(self, tmp_path) -> None: + def test_no_memory_limit_when_none(self, tmp_path: Path) -> None: """When cache_max_memory_mb is None, only size limit applies.""" db = SqliterDB( - tmp_path / "test.db", + str(tmp_path / "test.db"), cache_enabled=True, cache_max_size=5, cache_max_memory_mb=None, @@ -964,9 +983,9 @@ def test_no_memory_limit_when_none(self, tmp_path) -> None: class TestQueryLevelBypass: """Test query-level cache bypass controls.""" - def test_bypass_cache_skips_cache_read(self, tmp_path) -> None: + def test_bypass_cache_skips_cache_read(self, tmp_path: Path) -> None: """bypass_cache() skips reading from cache.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1002,9 +1021,9 @@ def test_bypass_cache_skips_cache_read(self, tmp_path) -> None: db.close() - def test_bypass_cache_skips_cache_write(self, tmp_path) -> None: + def test_bypass_cache_skips_cache_write(self, tmp_path: Path) -> None: """bypass_cache() doesn't write to cache.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1021,9 +1040,9 @@ def test_bypass_cache_skips_cache_write(self, tmp_path) -> None: db.close() - def test_bypass_cache_with_filter_chain(self, tmp_path) -> None: + def test_bypass_cache_with_filter_chain(self, tmp_path: Path) -> None: """bypass_cache() works with method chaining.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1043,9 +1062,9 @@ def test_bypass_cache_with_filter_chain(self, tmp_path) -> None: db.close() - def test_bypass_cache_with_empty_result(self, tmp_path) -> None: + def test_bypass_cache_with_empty_result(self, tmp_path: Path) -> None: """bypass_cache() works with empty results.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1062,10 +1081,12 @@ def test_bypass_cache_with_empty_result(self, tmp_path) -> None: class TestQueryLevelTtl: """Test query-level TTL controls.""" - def test_query_ttl_overrides_global_ttl(self, tmp_path) -> None: + def test_query_ttl_overrides_global_ttl(self, tmp_path: Path) -> None: """Query-level TTL overrides global cache_ttl.""" # Global TTL of 10 seconds - db = SqliterDB(tmp_path / "test.db", cache_enabled=True, cache_ttl=10) + db = SqliterDB( + str(tmp_path / "test.db"), cache_enabled=True, cache_ttl=10 + ) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1082,10 +1103,12 @@ def test_query_ttl_overrides_global_ttl(self, tmp_path) -> None: db.close() - def test_query_ttl_longer_than_global(self, tmp_path) -> None: + def test_query_ttl_longer_than_global(self, tmp_path: Path) -> None: """Query-level TTL can be longer than global TTL.""" # Global TTL of 1 second - db = SqliterDB(tmp_path / "test.db", cache_enabled=True, cache_ttl=1) + db = SqliterDB( + str(tmp_path / "test.db"), cache_enabled=True, cache_ttl=1 + ) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1102,10 +1125,10 @@ def test_query_ttl_longer_than_global(self, tmp_path) -> None: db.close() - def test_query_ttl_without_global_ttl(self, tmp_path) -> None: + def test_query_ttl_without_global_ttl(self, tmp_path: Path) -> None: """Query-level TTL works without global TTL.""" # No global TTL - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1122,9 +1145,9 @@ def test_query_ttl_without_global_ttl(self, tmp_path) -> None: db.close() - def test_query_ttl_with_method_chaining(self, tmp_path) -> None: + def test_query_ttl_with_method_chaining(self, tmp_path: Path) -> None: """cache_ttl() works with method chaining.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1142,9 +1165,9 @@ def test_query_ttl_with_method_chaining(self, tmp_path) -> None: db.close() - def test_query_ttl_validates_non_negative(self, tmp_path) -> None: + def test_query_ttl_validates_non_negative(self, tmp_path: Path) -> None: """cache_ttl() raises ValueError for negative values.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) # Negative TTL should raise ValueError @@ -1155,10 +1178,12 @@ def test_query_ttl_validates_non_negative(self, tmp_path) -> None: def test_query_ttl_different_for_different_queries( self, - tmp_path, + tmp_path: Path, ) -> None: """Different queries can have different TTLs.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True, cache_ttl=100) + db = SqliterDB( + str(tmp_path / "test.db"), cache_enabled=True, cache_ttl=100 + ) db.create_table(User) db.insert(User(name="Alice", age=30)) db.insert(User(name="Bob", age=25)) @@ -1191,10 +1216,10 @@ class TestFetchModeCacheKey: def test_fetch_one_and_fetch_all_use_different_cache_keys( self, - tmp_path, + tmp_path: Path, ) -> None: """fetch_one() and fetch_all() should generate different cache keys.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) db.insert(User(name="Bob", age=25)) @@ -1237,9 +1262,11 @@ class TestEmptyResultCaching: being returned from cache due to truthiness-based cache hit detection. """ - def test_empty_result_from_fetch_one_is_cached(self, tmp_path) -> None: + def test_empty_result_from_fetch_one_is_cached( + self, tmp_path: Path + ) -> None: """None result from fetch_one() should be cached and retrieved.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1261,9 +1288,11 @@ def test_empty_result_from_fetch_one_is_cached(self, tmp_path) -> None: db.close() - def test_empty_result_from_fetch_all_is_cached(self, tmp_path) -> None: + def test_empty_result_from_fetch_all_is_cached( + self, tmp_path: Path + ) -> None: """Empty list [] from fetch_all() should be cached and retrieved.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1287,11 +1316,11 @@ def test_empty_result_from_fetch_all_is_cached(self, tmp_path) -> None: def test_overwriting_cache_key_updates_memory_and_lru( self, - tmp_path, + tmp_path: Path, ) -> None: """Overwriting an existing cache key updates LRU position.""" db = SqliterDB( - tmp_path / "test.db", cache_enabled=True, cache_max_size=3 + str(tmp_path / "test.db"), cache_enabled=True, cache_max_size=3 ) db.create_table(User) db.insert(User(name="Alice", age=30)) @@ -1337,10 +1366,10 @@ class TestCacheKeyErrors: def test_incomparable_filter_types_raises_error( self, - tmp_path, + tmp_path: Path, ) -> None: """Filters with incomparable types raise ValueError.""" - db = SqliterDB(tmp_path / "test.db", cache_enabled=True) + db = SqliterDB(str(tmp_path / "test.db"), cache_enabled=True) db.create_table(User) db.insert(User(name="Alice", age=30)) diff --git a/tests/test_context_manager.py b/tests/test_context_manager.py index 919330ee..629e1d65 100644 --- a/tests/test_context_manager.py +++ b/tests/test_context_manager.py @@ -1,8 +1,10 @@ """Test the context-manager functionality.""" import sqlite3 +from pathlib import Path import pytest +from pytest_mock import MockerFixture from sqliter.sqliter import SqliterDB from tests.conftest import ExampleModel @@ -11,7 +13,9 @@ class TestContextManager: """Test the context-manager functionality.""" - def test_transaction_commit_success(self, db_mock, mocker) -> None: + def test_transaction_commit_success( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that the transaction commits successfully with no exceptions.""" # Mock the connection's commit method to track the commit mock_commit = mocker.patch.object(db_mock, "conn", create=True) @@ -24,7 +28,9 @@ def test_transaction_commit_success(self, db_mock, mocker) -> None: # Ensure commit was called mock_commit.commit.assert_called_once() - def test_transaction_closes_connection(self, db_mock, mocker) -> None: + def test_transaction_closes_connection( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test the connection is closed after the transaction completes.""" # Mock the connection object itself mock_conn = mocker.patch.object(db_mock, "conn", autospec=True) @@ -36,7 +42,9 @@ def test_transaction_closes_connection(self, db_mock, mocker) -> None: # Ensure the connection is closed mock_conn.close.assert_called_once() - def test_transaction_rollback_on_exception(self, db_mock, mocker) -> None: + def test_transaction_rollback_on_exception( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that the transaction rolls back when an exception occurs.""" # Mock the connection object and ensure it's set as db_mock.conn mock_conn = mocker.Mock() @@ -51,7 +59,7 @@ def test_transaction_rollback_on_exception(self, db_mock, mocker) -> None: mock_conn.rollback.assert_called_once() mock_conn.commit.assert_not_called() - def test_in_transaction_flag(self, db_mock) -> None: + def test_in_transaction_flag(self, db_mock: SqliterDB) -> None: """Test that _in_transaction is set/unset inside a transaction.""" assert not db_mock._in_transaction # Initially, it should be False @@ -62,7 +70,9 @@ def test_in_transaction_flag(self, db_mock) -> None: not db_mock._in_transaction ) # Should be False again after exiting the context - def test_rollback_resets_in_transaction_flag(self, db_mock, mocker) -> None: + def test_rollback_resets_in_transaction_flag( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that _in_transaction is reset after a rollback on exception.""" def test_transaction() -> None: @@ -80,7 +90,9 @@ def test_transaction() -> None: not db_mock._in_transaction ) # Should be reset to False after exception - def test_maybe_commit_skips_in_transaction(self, db_mock, mocker) -> None: + def test_maybe_commit_skips_in_transaction( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that maybe_commit does not commit when inside a transaction.""" mock_conn = mocker.Mock() mocker.patch.object(db_mock, "conn", mock_conn) @@ -92,7 +104,9 @@ def test_maybe_commit_skips_in_transaction(self, db_mock, mocker) -> None: db_mock._maybe_commit() mock_conn.commit.assert_called_once() - def test_commit_called_once_in_transaction(self, mocker, tmp_path) -> None: + def test_commit_called_once_in_transaction( + self, mocker: MockerFixture, tmp_path: Path + ) -> None: """Ensure data is committed at the end of a transaction.""" # Create a temporary database file db_file = tmp_path / "test.db" diff --git a/tests/test_dates.py b/tests/test_dates.py index 4331b414..7e051d0d 100644 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -7,6 +7,7 @@ from sqliter.helpers import from_unix_timestamp, to_unix_timestamp from sqliter.model.model import BaseDBModel +from sqliter.sqliter import SqliterDB class TestDates: @@ -101,7 +102,9 @@ def test_from_unix_timestamp_invalid_type(self) -> None: with pytest.raises(TypeError): from_unix_timestamp(1697803200, str) - def test_date_fields_create_integer_columns(self, db_mock) -> None: + def test_date_fields_create_integer_columns( + self, db_mock: SqliterDB + ) -> None: """Test that date & datetime fields create INTEGER columns in SQLite.""" class DateModel(BaseDBModel): @@ -126,7 +129,7 @@ class Meta: assert columns["date_field"] == "INTEGER" assert columns["datetime_field"] == "INTEGER" - def test_date_field_roundtrip(self, db_mock) -> None: + def test_date_field_roundtrip(self, db_mock: SqliterDB) -> None: """Test that dates survive a round trip to and from the database.""" test_datetime = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc) test_date = date(2024, 1, 1) @@ -154,7 +157,7 @@ class Meta: assert fetched.date_field == test_date assert fetched.datetime_field == test_datetime - def test_datetime_different_timezones(self, db_mock) -> None: + def test_datetime_different_timezones(self, db_mock: SqliterDB) -> None: """Test handling of datetimes in different timezones.""" class TimezoneModel(BaseDBModel): @@ -191,7 +194,7 @@ class Meta: assert fetched_1.dt_field.timestamp() == test_dt_plus_2.timestamp() assert fetched_2.dt_field.timestamp() == test_dt_minus_5.timestamp() - def test_date_edge_cases(self, db_mock) -> None: + def test_date_edge_cases(self, db_mock: SqliterDB) -> None: """Test dates near Unix timestamp boundaries.""" class EdgeDateModel(BaseDBModel): @@ -227,7 +230,7 @@ class Meta: assert fetched is not None assert fetched.dt_field.timestamp() == original_dt.timestamp() - def test_optional_date_fields(self, db_mock) -> None: + def test_optional_date_fields(self, db_mock: SqliterDB) -> None: """Test handling of Optional[date] and Optional[datetime] fields.""" class OptionalDateModel(BaseDBModel): @@ -262,12 +265,13 @@ class Meta: fetched_value = db_mock.get(OptionalDateModel, inserted_value.pk) assert fetched_value is not None assert fetched_value.date_field == date(2024, 1, 1) + assert fetched_value.dt_field is not None assert ( fetched_value.dt_field.timestamp() == datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc).timestamp() ) - def test_update_date_fields(self, db_mock) -> None: + def test_update_date_fields(self, db_mock: SqliterDB) -> None: """Test updating date and datetime fields.""" class UpdateDateModel(BaseDBModel): diff --git a/tests/test_debug_logging.py b/tests/test_debug_logging.py index abb30429..ba7878b2 100644 --- a/tests/test_debug_logging.py +++ b/tests/test_debug_logging.py @@ -2,6 +2,9 @@ import logging +import pytest +from pytest_mock import MockerFixture + from sqliter.sqliter import SqliterDB from tests.conftest import ComplexModel @@ -29,7 +32,7 @@ def test_sqliterdb_debug_set_true(self) -> None: ) def test_debug_sql_output_basic_query( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test the debug output correctly prints the SQL query and values.""" with caplog.at_level(logging.DEBUG): @@ -46,7 +49,7 @@ def test_debug_sql_output_basic_query( ) def test_debug_sql_output_string_values( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test that the debug output correctly handles string values.""" with caplog.at_level(logging.DEBUG): @@ -63,7 +66,7 @@ def test_debug_sql_output_string_values( ) def test_debug_sql_output_multiple_conditions( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test that the debug output works with multiple conditions.""" with caplog.at_level(logging.DEBUG): @@ -80,7 +83,7 @@ def test_debug_sql_output_multiple_conditions( ) def test_debug_sql_output_order_and_limit( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test that the debug output works with order and limit.""" with caplog.at_level(logging.DEBUG): @@ -97,7 +100,7 @@ def test_debug_sql_output_order_and_limit( ) def test_debug_sql_output_with_null_value( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test that the debug output works when filtering on a NULL value.""" with caplog.at_level(logging.DEBUG): @@ -125,7 +128,7 @@ def test_debug_sql_output_with_null_value( ) def test_debug_sql_output_with_fields_single( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test debug output correct when selecting a single field.""" with caplog.at_level(logging.DEBUG): @@ -140,7 +143,7 @@ def test_debug_sql_output_with_fields_single( ) def test_debug_sql_output_with_fields_multiple( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test that the debug output correct when selecting multiple fields.""" with caplog.at_level(logging.DEBUG): @@ -155,7 +158,7 @@ def test_debug_sql_output_with_fields_multiple( ) def test_debug_sql_output_with_fields_and_filter( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test the debug output correct with selected fields and a filter.""" with caplog.at_level(logging.DEBUG): @@ -169,7 +172,9 @@ def test_debug_sql_output_with_fields_and_filter( 'WHERE "score" > 85' in caplog.text ) - def test_no_log_output_when_debug_false(self, caplog) -> None: + def test_no_log_output_when_debug_false( + self, caplog: pytest.LogCaptureFixture + ) -> None: """Test that no log output occurs when debug=False.""" db = SqliterDB(":memory:", debug=False) db.create_table(ComplexModel) @@ -181,7 +186,7 @@ def test_no_log_output_when_debug_false(self, caplog) -> None: assert caplog.text == "" def test_no_log_output_above_debug_level( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test no DEBUG log output occurs when log level is above DEBUG.""" with caplog.at_level(logging.INFO): # Set log level higher than DEBUG @@ -192,7 +197,9 @@ def test_no_log_output_above_debug_level( # Assert that no DEBUG messages are present in the logs assert caplog.text == "" - def test_manual_logger_respects_debug_flag(self, caplog) -> None: + def test_manual_logger_respects_debug_flag( + self, caplog: pytest.LogCaptureFixture + ) -> None: """Test that a manually passed logger respects the debug flag.""" custom_logger = logging.getLogger("CustomLogger") custom_logger.setLevel(logging.DEBUG) @@ -208,7 +215,9 @@ def test_manual_logger_respects_debug_flag(self, caplog) -> None: '"age", "is_active", "score", ' in caplog.text ) - def test_manual_logger_above_debug_level(self, caplog) -> None: + def test_manual_logger_above_debug_level( + self, caplog: pytest.LogCaptureFixture + ) -> None: """Ensure no log output when manually passed logger is above DEBUG.""" custom_logger = logging.getLogger("CustomLogger") custom_logger.setLevel(logging.INFO) # Set log level higher than DEBUG @@ -223,7 +232,7 @@ def test_manual_logger_above_debug_level(self, caplog) -> None: assert caplog.text == "" def test_debug_sql_output_no_matching_records( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test the debug output occurs even when no records match the query.""" with caplog.at_level(logging.DEBUG): @@ -240,7 +249,7 @@ def test_debug_sql_output_no_matching_records( ) def test_debug_sql_output_empty_query( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test debug output occurs for empty query (no filters, etc).""" with caplog.at_level(logging.DEBUG): @@ -254,7 +263,7 @@ def test_debug_sql_output_empty_query( ) def test_debug_output_drop_table( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test debug output when dropping a table.""" with caplog.at_level(logging.DEBUG): @@ -265,7 +274,9 @@ def test_debug_output_drop_table( "Executing SQL: DROP TABLE IF EXISTS complex_model" in caplog.text ) - def test_reset_database_debug_logging(self, temp_db_path, caplog) -> None: + def test_reset_database_debug_logging( + self, temp_db_path: str, caplog: pytest.LogCaptureFixture + ) -> None: """Test that resetting the database logs debug information.""" with caplog.at_level(logging.DEBUG): SqliterDB(temp_db_path, reset=True, debug=True) @@ -273,7 +284,7 @@ def test_reset_database_debug_logging(self, temp_db_path, caplog) -> None: assert "Database reset: 0 user-created tables dropped." in caplog.text def test_debug_output_insert( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test that insert operations produce debug log output.""" with caplog.at_level(logging.DEBUG): @@ -291,7 +302,7 @@ def test_debug_output_insert( assert "INSERT INTO complex_model" in caplog.text def test_debug_output_get( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test that get operations produce debug log output.""" with caplog.at_level(logging.DEBUG): @@ -302,7 +313,7 @@ def test_debug_output_get( assert "complex_model" in caplog.text def test_debug_output_update( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test that update operations produce debug log output.""" record = db_mock_complex_debug.get(ComplexModel, 1) @@ -316,7 +327,7 @@ def test_debug_output_update( assert "UPDATE complex_model" in caplog.text def test_debug_output_delete( - self, db_mock_complex_debug: SqliterDB, caplog + self, db_mock_complex_debug: SqliterDB, caplog: pytest.LogCaptureFixture ) -> None: """Test that delete operations produce debug log output.""" with caplog.at_level(logging.DEBUG): @@ -325,7 +336,9 @@ def test_debug_output_delete( assert "Executing SQL:" in caplog.text assert "DELETE FROM complex_model" in caplog.text - def test_debug_output_table_names(self, caplog) -> None: + def test_debug_output_table_names( + self, caplog: pytest.LogCaptureFixture + ) -> None: """Test that table_names property produces debug log output.""" db = SqliterDB(":memory:", debug=True) db.create_table(ComplexModel) @@ -336,10 +349,10 @@ def test_debug_output_table_names(self, caplog) -> None: assert "Executing SQL:" in caplog.text assert "sqlite_master" in caplog.text - def test_setup_logger_else_clause(self, mocker) -> None: + def test_setup_logger_else_clause(self, mocker: MockerFixture) -> None: """Test the else clause configuration for the logger setup.""" # Mock the root logger's hasHandlers BEFORE creating the instance - mocker.patch.object( + has_handlers_mock = mocker.patch.object( logging.getLogger(), "hasHandlers", return_value=False ) @@ -360,7 +373,7 @@ def test_setup_logger_else_clause(self, mocker) -> None: assert logger.level == logging.DEBUG assert logger.propagate is False - logging.getLogger().hasHandlers.assert_called_once() + has_handlers_mock.assert_called_once() # Cleanup - crucial to prevent test pollution for hdlr in logger.handlers[:]: diff --git a/tests/test_drop_table.py b/tests/test_drop_table.py index 588c9e2e..84cd29d5 100644 --- a/tests/test_drop_table.py +++ b/tests/test_drop_table.py @@ -3,6 +3,7 @@ import sqlite3 import pytest +from pytest_mock import MockerFixture from sqliter import SqliterDB from sqliter.exceptions import TableDeletionError @@ -12,7 +13,7 @@ class TestDropTable: """Test class for the 'drop_table' method.""" - def test_drop_existing_table(self, db_mock) -> None: + def test_drop_existing_table(self, db_mock: SqliterDB) -> None: """Test dropping an existing table.""" class TestModel(BaseDBModel): @@ -34,7 +35,7 @@ class Meta: result = cursor.fetchone() assert result is None - def test_drop_non_existent_table(self, db_mock) -> None: + def test_drop_non_existent_table(self, db_mock: SqliterDB) -> None: """Test dropping a table that doesn't exist.""" class NonExistentModel(BaseDBModel): @@ -46,7 +47,7 @@ class Meta: # This should not raise an exception due to 'IF EXISTS' in the SQL db_mock.drop_table(NonExistentModel) - def test_drop_table_with_data(self, db_mock) -> None: + def test_drop_table_with_data(self, db_mock: SqliterDB) -> None: """Test dropping a table that contains data.""" class DataModel(BaseDBModel): @@ -70,7 +71,9 @@ class Meta: result = cursor.fetchone() assert result is None - def test_drop_table_error(self, db_mock: SqliterDB, mocker) -> None: + def test_drop_table_error( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test error handling when dropping a table fails.""" class ErrorModel(BaseDBModel): @@ -88,7 +91,9 @@ class Meta: exc_info.value ) - def test_drop_table_auto_commit(self, db_mock, mocker) -> None: + def test_drop_table_auto_commit( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test auto-commit behavior when dropping a table.""" class CommitModel(BaseDBModel): diff --git a/tests/test_execeptions.py b/tests/test_execeptions.py index 0bb30096..a6b77b0d 100644 --- a/tests/test_execeptions.py +++ b/tests/test_execeptions.py @@ -3,6 +3,7 @@ import sqlite3 import pytest +from pytest_mock import MockerFixture from sqliter.exceptions import ( DatabaseConnectionError, @@ -38,7 +39,7 @@ def test_sqliter_error_without_template(self) -> None: assert str(exc) == "An error occurred in the SQLiter package." - def test_database_connection_error(self, mocker) -> None: + def test_database_connection_error(self, mocker: MockerFixture) -> None: """Test that DatabaseConnectionError is raised when connection fails.""" # Mock sqlite3.connect to raise an error mocker.patch("sqlite3.connect", side_effect=sqlite3.Error) @@ -56,7 +57,7 @@ def test_database_connection_error(self, mocker) -> None: ) # @pytest.mark.skip(reason="This is no longer a valid test case.") - def test_insert_duplicate_primary_key(self, db_mock) -> None: + def test_insert_duplicate_primary_key(self, db_mock: SqliterDB) -> None: """Test that exception raised when inserting duplicate primary key.""" # Create a model instance with a unique primary key example_model = ExampleModel( @@ -75,7 +76,9 @@ def test_insert_duplicate_primary_key(self, db_mock) -> None: exc_info.value ) - def test_create_table_error(self, db_mock, mocker) -> None: + def test_create_table_error( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test exception is raised when creating table with invalid model.""" # Mock sqlite3.connect to raise an error mocker.patch("sqliter.SqliterDB.connect", side_effect=sqlite3.Error) @@ -87,7 +90,7 @@ def test_create_table_error(self, db_mock, mocker) -> None: # Verify that the exception message contains the table name assert "Failed to create the table: 'test_table'" in str(exc_info.value) - def test_update_not_found_error(self, db_mock) -> None: + def test_update_not_found_error(self, db_mock: SqliterDB) -> None: """Test exception raised when updating a record that does not exist.""" # Create a model instance with a unique primary key example_model = ExampleModel( @@ -103,7 +106,9 @@ def test_update_not_found_error(self, db_mock) -> None: exc_info.value ) - def test_update_exception_error(self, db_mock, mocker) -> None: + def test_update_exception_error( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test an exception is raised when updating a record with an error.""" # Create a model instance with a unique primary key example_model = ExampleModel( @@ -125,7 +130,9 @@ def test_update_exception_error(self, db_mock, mocker) -> None: exc_info.value ) - def test_delete_exception_error(self, db_mock, mocker) -> None: + def test_delete_exception_error( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that exception raised when deleting a record with an error.""" # Create a model instance with a unique primary key example_model = ExampleModel( diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 6ad445a5..e8ad4678 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -5,6 +5,7 @@ import pytest from pydantic import Field +from pytest_mock import MockerFixture from sqliter import SqliterDB from sqliter.exceptions import ( @@ -179,7 +180,7 @@ def test_create_table_with_fk(self) -> None: assert "authors" in db.table_names assert "books" in db.table_names - def test_fk_constraint_sql_generation(self, mocker) -> None: + def test_fk_constraint_sql_generation(self, mocker: MockerFixture) -> None: """Test that FK constraint SQL is generated correctly.""" mock_cursor = mocker.MagicMock() mocker.patch.object( @@ -618,7 +619,9 @@ class TestBook(BaseDBModel): class TestGetForeignKeyInfoEdgeCases: """Test edge cases for get_foreign_key_info function.""" - def test_field_without_json_schema_extra_attribute(self, mocker) -> None: + def test_field_without_json_schema_extra_attribute( + self, mocker: MockerFixture + ) -> None: """Test get_foreign_key_info with field lacking json_schema_extra.""" # Create a mock FieldInfo without json_schema_extra attribute mock_field_info = mocker.MagicMock( @@ -648,7 +651,7 @@ class TestBook(BaseDBModel): class TestForeignKeyDatabaseErrors: """Test database error handling for FK operations.""" - def test_insert_general_database_error(self, mocker) -> None: + def test_insert_general_database_error(self, mocker: MockerFixture) -> None: """Test that general sqlite3.Error during insert raises properly.""" db = SqliterDB(":memory:") db.create_table(Author) @@ -669,7 +672,7 @@ def test_insert_general_database_error(self, mocker) -> None: with pytest.raises(RecordInsertionError): db.insert(Author(name="Test", email="test@example.com")) - def test_delete_non_fk_integrity_error(self, mocker) -> None: + def test_delete_non_fk_integrity_error(self, mocker: MockerFixture) -> None: """Test delete with IntegrityError that is not FK-related.""" db = SqliterDB(":memory:") db.create_table(Author) diff --git a/tests/test_foreign_keys_orm.py b/tests/test_foreign_keys_orm.py index b9430721..493df4a2 100644 --- a/tests/test_foreign_keys_orm.py +++ b/tests/test_foreign_keys_orm.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from collections.abc import Generator @@ -152,7 +152,7 @@ def mock_get(model_class: type, pk: int) -> object: # Clear any cached loader if hasattr(book, "_fk_cache"): - book._fk_cache.clear() + cast("dict[str, Any]", book._fk_cache).clear() # Accessing the author should raise AttributeError because # the LazyLoader returns None for the loaded object @@ -174,7 +174,7 @@ def test_reverse_relationship_fetch_all(self, db: SqliterDB) -> None: db.insert(Book(title="Book 2", author=author)) # Fetch books via reverse relationship - books = author.books.fetch_all() + books = cast("list[Book]", cast("Any", author.books).fetch_all()) assert len(books) == 2 assert {b.title for b in books} == {"Book 1", "Book 2"} @@ -189,7 +189,10 @@ def test_reverse_relationship_filter(self, db: SqliterDB) -> None: db.insert(Book(title="Book C", author=author)) # Filter books by title - books = author.books.filter(title__like="Book A%").fetch_all() + books = cast( + "list[Book]", + cast("Any", author.books).filter(title__like="Book A%").fetch_all(), + ) assert len(books) == 1 assert books[0].title == "Book A" @@ -204,7 +207,7 @@ def test_reverse_relationship_count(self, db: SqliterDB) -> None: db.insert(Book(title="Book 3", author=author)) # Count books - count = author.books.count() + count = cast("Any", author.books).count() assert count == 3 def test_reverse_relationship_exists(self, db: SqliterDB) -> None: @@ -213,10 +216,10 @@ def test_reverse_relationship_exists(self, db: SqliterDB) -> None: db.create_table(Book) author = db.insert(Author(name="Eve", email="eve@example.com")) - assert not author.books.exists() + assert not cast("Any", author.books).exists() db.insert(Book(title="Book 1", author=author)) - assert author.books.exists() + assert cast("Any", author.books).exists() def test_reverse_relationship_empty(self, db: SqliterDB) -> None: """Test reverse relationship with no related objects.""" @@ -224,7 +227,7 @@ def test_reverse_relationship_empty(self, db: SqliterDB) -> None: db.create_table(Book) author = db.insert(Author(name="Frank", email="frank@example.com")) - books = author.books.fetch_all() + books = cast("list[Book]", cast("Any", author.books).fetch_all()) assert books == [] def test_reverse_relationship_limit_offset(self, db: SqliterDB) -> None: @@ -239,11 +242,16 @@ def test_reverse_relationship_limit_offset(self, db: SqliterDB) -> None: db.insert(Book(title="Book 4", author=author)) # Limit - books = author.books.limit(2).fetch_all() + books = cast( + "list[Book]", cast("Any", author.books).limit(2).fetch_all() + ) assert len(books) == 2 # Offset - books = author.books.limit(2).offset(1).fetch_all() + books = cast( + "list[Book]", + cast("Any", author.books).limit(2).offset(1).fetch_all(), + ) assert len(books) == 2 def test_reverse_relationship_with_custom_related_name( @@ -268,7 +276,9 @@ class CustomBook(BaseDBModel): db.insert(CustomBook(title="Book 2", author=author)) # Use custom related name - books = author.publications.fetch_all() + books = cast( + "list[CustomBook]", cast("Any", author.publications).fetch_all() + ) assert len(books) == 2 @@ -354,7 +364,7 @@ def test_prefetch_reverse_fk_with_custom_fk_column( .fetch_all() ) counts = { - author.name: len(author.custom_books.fetch_all()) + author.name: len(cast("Any", author.custom_books).fetch_all()) for author in authors } assert counts == {"Alice": 2, "Bob": 1} @@ -498,7 +508,7 @@ def test_cascade_delete_with_reverse_relationships( db.insert(Book(title="Book 2", author=author)) # Get books via reverse relationship - books = author.books.fetch_all() + books = cast("list[Book]", cast("Any", author.books).fetch_all()) assert len(books) == 2 # Delete author - should cascade delete books @@ -552,7 +562,7 @@ def test_repr_unloaded(self, db: SqliterDB) -> None: # Clear cache to get fresh LazyLoader if hasattr(book, "_fk_cache"): - book._fk_cache.clear() + cast("dict[str, Any]", book._fk_cache).clear() # Get the LazyLoader without triggering load lazy = book.__dict__.get("_fk_cache", {}).get("author") @@ -586,7 +596,7 @@ def test_repr_loaded(self, db: SqliterDB) -> None: _ = book.author.name # Get the cached LazyLoader - lazy = book._fk_cache.get("author") + lazy = cast("dict[str, Any]", book._fk_cache).get("author") assert lazy is not None # LazyLoader exists in cache repr_str = repr(lazy) @@ -797,7 +807,7 @@ def test_fetch_one(self, db: SqliterDB) -> None: db.insert(Book(title="Book 2", author=author)) # fetch_one should return a single book - book = author.books.fetch_one() + book = cast("Any", author.books).fetch_one() assert book is not None assert book.title in {"Book 1", "Book 2"} @@ -809,7 +819,7 @@ def test_fetch_one_empty(self, db: SqliterDB) -> None: author = db.insert(Author(name="NoBooks", email="nb@example.com")) # fetch_one should return None - book = author.books.fetch_one() + book = cast("Any", author.books).fetch_one() assert book is None def test_fetch_all_no_db_context(self) -> None: @@ -818,7 +828,7 @@ def test_fetch_all_no_db_context(self) -> None: author = Author(name="NoContext", email="nc@example.com") # Should return empty list - books = author.books.fetch_all() + books = cast("list[Book]", cast("Any", author.books).fetch_all()) assert books == [] def test_count_no_db_context(self) -> None: @@ -827,7 +837,7 @@ def test_count_no_db_context(self) -> None: author = Author(name="NoContext", email="nc@example.com") # Should return 0 - count = author.books.count() + count = cast("Any", author.books).count() assert count == 0 def test_count_with_filters(self, db: SqliterDB) -> None: @@ -841,7 +851,7 @@ def test_count_with_filters(self, db: SqliterDB) -> None: db.insert(Book(title="Python Guide", author=author)) # Count with filter - count = author.books.filter(title__like="Python%").count() + count = cast("Any", author.books).filter(title__like="Python%").count() assert count == 2 @@ -851,7 +861,8 @@ class TestReverseRelationshipDescriptor: def test_class_level_access(self) -> None: """Test accessing reverse relationship on class returns descriptor.""" # Access on class, not instance - descriptor = Author.books + reverse_attr = "books" + descriptor = cast("Any", getattr(Author, reverse_attr)) assert isinstance(descriptor, ReverseRelationship) def test_cannot_set_reverse_relationship(self, db: SqliterDB) -> None: @@ -865,10 +876,11 @@ def test_cannot_set_reverse_relationship(self, db: SqliterDB) -> None: author = db.insert(Author(name="NoSet", email="ns@example.com")) # Call descriptor directly to bypass Pydantic's __setattr__ + reverse_attr = "books" with pytest.raises( AttributeError, match="Cannot set reverse relationship" ): - Author.books.__set__(author, []) + cast("Any", getattr(Author, reverse_attr)).__set__(author, []) class TestRegistryPendingRelationships: diff --git a/tests/test_m2m.py b/tests/test_m2m.py index 6593697a..6a0698f4 100644 --- a/tests/test_m2m.py +++ b/tests/test_m2m.py @@ -231,7 +231,8 @@ class TargetResolved(BaseDBModel): def test_reverse_descriptor_metadata(self) -> None: """Reverse descriptor exposes SQL metadata in reverse orientation.""" - desc = Tag.articles + reverse_attr = "articles" + desc = cast("Any", getattr(Tag, reverse_attr)) metadata = desc.sql_metadata assert isinstance(metadata, M2MSQLMetadata) assert metadata.junction_table == "articles_tags" @@ -255,7 +256,7 @@ def test_manager_metadata_forward(self, db: SqliterDB) -> None: def test_manager_metadata_reverse(self, db: SqliterDB) -> None: """Reverse manager metadata matches reverse accessor orientation.""" tag = db.insert(Tag(name="python")) - metadata = tag.articles.sql_metadata + metadata = cast("Any", tag.articles).sql_metadata assert metadata.junction_table == "articles_tags" assert metadata.from_column == "tags_pk" assert metadata.to_column == "articles_pk" @@ -303,7 +304,8 @@ class PersonMeta(BaseDBModel): related_name="followers", ) - desc = PersonMeta.followers + reverse_attr = "followers" + desc = cast("Any", getattr(PersonMeta, reverse_attr)) metadata = desc.sql_metadata person_table = PersonMeta.get_table_name() assert metadata.from_column == f"{person_table}_pk_right" @@ -333,7 +335,8 @@ def test_forward_descriptor_metadata_is_cached(self) -> None: def test_reverse_descriptor_metadata_is_cached(self) -> None: """Reverse descriptor metadata should be memoized.""" - desc = Tag.articles + reverse_attr = "articles" + desc = cast("Any", getattr(Tag, reverse_attr)) first = desc.sql_metadata second = desc.sql_metadata assert first is second @@ -761,7 +764,7 @@ def test_reverse_fetch_all(self, db: SqliterDB) -> None: article1.tags.add(tag) article2.tags.add(tag) - articles = tag.articles.fetch_all() + articles = cast("Any", tag.articles).fetch_all() assert len(articles) == 2 titles = {a.title for a in articles} assert titles == {"Guide 1", "Guide 2"} @@ -771,8 +774,8 @@ def test_reverse_add(self, db: SqliterDB) -> None: article = db.insert(Article(title="Guide")) tag = db.insert(Tag(name="python")) - tag.articles.add(article) - assert tag.articles.count() == 1 + cast("Any", tag.articles).add(article) + assert cast("Any", tag.articles).count() == 1 # Also visible from the forward side assert article.tags.count() == 1 @@ -782,9 +785,9 @@ def test_reverse_remove(self, db: SqliterDB) -> None: tag = db.insert(Tag(name="python")) article.tags.add(tag) - tag.articles.remove(article) + cast("Any", tag.articles).remove(article) - assert tag.articles.count() == 0 + assert cast("Any", tag.articles).count() == 0 assert article.tags.count() == 0 def test_reverse_clear(self, db: SqliterDB) -> None: @@ -795,9 +798,9 @@ def test_reverse_clear(self, db: SqliterDB) -> None: article1.tags.add(tag) article2.tags.add(tag) - tag.articles.clear() + cast("Any", tag.articles).clear() - assert tag.articles.count() == 0 + assert cast("Any", tag.articles).count() == 0 def test_reverse_count(self, db: SqliterDB) -> None: """Reverse count() works correctly.""" @@ -805,7 +808,7 @@ def test_reverse_count(self, db: SqliterDB) -> None: tag = db.insert(Tag(name="python")) article.tags.add(tag) - assert tag.articles.count() == 1 + assert cast("Any", tag.articles).count() == 1 def test_reverse_cannot_set(self, db: SqliterDB) -> None: """Direct assignment to reverse M2M raises AttributeError.""" @@ -814,12 +817,12 @@ def test_reverse_cannot_set(self, db: SqliterDB) -> None: tag.articles = [] def test_reverse_set_handler_allows_noop( - self, db: SqliterDB, monkeypatch + self, db: SqliterDB, monkeypatch: pytest.MonkeyPatch ) -> None: """Reverse M2M handler returns True when __set__ does not raise.""" tag = db.insert(Tag(name="python")) - def noop_set(self, instance: object, value: object) -> None: + def noop_set(self: object, instance: object, value: object) -> None: return None monkeypatch.setattr(ReverseManyToMany, "__set__", noop_set) @@ -908,7 +911,8 @@ def test_m2m_registered(self) -> None: def test_reverse_accessor_on_target(self) -> None: """Reverse accessor descriptor is on target model.""" - desc = Tag.articles + reverse_attr = "articles" + desc = cast("Any", getattr(Tag, reverse_attr)) assert isinstance(desc, ReverseManyToMany) def test_m2m_descriptors_classvar(self) -> None: @@ -1000,7 +1004,7 @@ def test_custom_through_operations(self, db_custom: SqliterDB) -> None: assert cats[0].label == "Tech" # Reverse - assert cat.posts.count() == 1 + assert cast("Any", cat.posts).count() == 1 # โ”€โ”€ TestManyToManyEdgeCases โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ @@ -1120,7 +1124,9 @@ class MissingTarget(BaseDBModel): finally: ModelRegistry.restore(state) - def test_inflect_import_error_fallback(self, monkeypatch) -> None: + def test_inflect_import_error_fallback( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: """Fallback related_name used when inflect import fails.""" original_import = builtins.__import__ @@ -1193,7 +1199,9 @@ class BadThrough(BaseDBModel): # pylint: disable=unused-variable finally: ModelRegistry.restore(state) - def test_create_m2m_junction_tables_import_error(self, monkeypatch) -> None: + def test_create_m2m_junction_tables_import_error( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: """ImportError in M2M setup is ignored.""" original_import = builtins.__import__ @@ -1298,7 +1306,9 @@ class Meta: u2 = db.insert(User(name="U2")) u1.follows.add(u2) - assert {u.name for u in u2.followed_by.fetch_all()} == {"U1"} + assert { + u.name for u in cast("Any", u2.followed_by).fetch_all() + } == {"U1"} assert u2.follows.count() == 0 finally: ModelRegistry.restore(state) @@ -1413,7 +1423,8 @@ def test_reverse_accessor_conflict_raises(self) -> None: class TargetConflict(BaseDBModel): name: str - TargetConflict.conflict_attr = object() + conflict_attr = "conflict_attr" + setattr(TargetConflict, conflict_attr, object()) # Python 3.10 wraps __set_name__ errors in RuntimeError. with pytest.raises( @@ -1460,7 +1471,7 @@ class Target(BaseDBModel): ModelRegistry.restore(state) def test_junction_table_resolution_failure_raises( - self, monkeypatch + self, monkeypatch: pytest.MonkeyPatch ) -> None: """If junction table can't be resolved, registration fails.""" state = ModelRegistry.snapshot() @@ -1471,7 +1482,7 @@ def test_junction_table_resolution_failure_raises( class TargetBad(BaseDBModel): name: str - def bad_junction(self, owner: type[BaseDBModel]) -> None: + def bad_junction(self: object, owner: type[BaseDBModel]) -> None: _ = self _ = owner @@ -1564,7 +1575,8 @@ class AuthorConflict(BaseDBModel): class BookConflict(BaseDBModel): title: str - AuthorConflict.conflict = object() + conflict_attr = "conflict" + setattr(AuthorConflict, conflict_attr, object()) ModelRegistry.register_model(AuthorConflict) with pytest.raises(AttributeError, match="Reverse relationship"): diff --git a/tests/test_optional_fields.py b/tests/test_optional_fields.py index 43334b8a..10b66c38 100644 --- a/tests/test_optional_fields.py +++ b/tests/test_optional_fields.py @@ -265,7 +265,7 @@ def test_fields_operator_no_fields_explicitly( for field in all_fields: assert hasattr(result, field) - def test_validate_fields_with_none(self, db_mock_adv) -> None: + def test_validate_fields_with_none(self, db_mock_adv: SqliterDB) -> None: """Test _validate_fields with self._fields set to None.""" # This test will indirectly invoke _validate_fields by creating a # QueryBuilder without specifying fields (i.e., self._fields will be @@ -276,7 +276,9 @@ def test_validate_fields_with_none(self, db_mock_adv) -> None: # since self._fields is None. assert query._fields is None - def test_direct_validate_fields_with_none(self, db_mock_adv) -> None: + def test_direct_validate_fields_with_none( + self, db_mock_adv: SqliterDB + ) -> None: """Test _validate_fields directly with self._fields set to None.""" # Create the query builder instance query = db_mock_adv.select(PersonModel, fields=None) diff --git a/tests/test_order_method.py b/tests/test_order_method.py index 5cdbb666..11104933 100644 --- a/tests/test_order_method.py +++ b/tests/test_order_method.py @@ -4,12 +4,13 @@ from sqliter.exceptions import InvalidOrderError from sqliter.model import BaseDBModel +from sqliter.sqliter import SqliterDB class TestOrderMethod: """Test class for the 'order' method in the QueryBuilder class.""" - def test_order_by_primary_key_default(self, db_mock) -> None: + def test_order_by_primary_key_default(self, db_mock: SqliterDB) -> None: """Test ordering by primary key when no field is specified.""" class OrderTestModel(BaseDBModel): @@ -31,7 +32,7 @@ class Meta: assert results[1].id == 2 assert results[2].id == 3 - def test_order_by_primary_key_reverse(self, db_mock) -> None: + def test_order_by_primary_key_reverse(self, db_mock: SqliterDB) -> None: """Test ordering by primary key in descending order.""" class OrderTestModel(BaseDBModel): @@ -53,7 +54,7 @@ class Meta: assert results[1].id == 2 assert results[2].id == 1 - def test_order_by_specified_field(self, db_mock) -> None: + def test_order_by_specified_field(self, db_mock: SqliterDB) -> None: """Test ordering by a specified field.""" class OrderTestModel(BaseDBModel): @@ -75,7 +76,7 @@ class Meta: assert results[1].name == "Bob" assert results[2].name == "Charlie" - def test_order_by_specified_field_reverse(self, db_mock) -> None: + def test_order_by_specified_field_reverse(self, db_mock: SqliterDB) -> None: """Test ordering by a specified field in descending order.""" class OrderTestModel(BaseDBModel): @@ -101,7 +102,7 @@ class Meta: assert results[1].name == "Bob" assert results[2].name == "Alice" - def test_order_with_reverse_false(self, db_mock) -> None: + def test_order_with_reverse_false(self, db_mock: SqliterDB) -> None: """Test the order method works with reverse=False (ascending order).""" class TestModel(BaseDBModel): @@ -124,7 +125,7 @@ class Meta: assert results[1].name == "Bob" assert results[2].name == "Charlie" - def test_order_invalid_field(self, db_mock) -> None: + def test_order_invalid_field(self, db_mock: SqliterDB) -> None: """Test ordering by an invalid field.""" class OrderTestModel(BaseDBModel): @@ -143,7 +144,7 @@ class Meta: exc.value ) - def test_order_both_direction_and_reverse(self, db_mock) -> None: + def test_order_both_direction_and_reverse(self, db_mock: SqliterDB) -> None: """Test ordering with both direction and reverse specified.""" class OrderTestModel(BaseDBModel): @@ -163,7 +164,7 @@ class Meta: "name", direction="ASC", reverse=True ).fetch_all() - def test_order_deprecation_warning(self, db_mock) -> None: + def test_order_deprecation_warning(self, db_mock: SqliterDB) -> None: """Test that using 'direction' raises a DeprecationWarning.""" class TestModel(BaseDBModel): @@ -174,7 +175,7 @@ class TestModel(BaseDBModel): ): db_mock.select(TestModel).order("name", direction="ASC") - def test_order_invalid_direction(self, db_mock) -> None: + def test_order_invalid_direction(self, db_mock: SqliterDB) -> None: """Test that an invalid order direction raises an exception.""" # Define a simple model for the test @@ -193,7 +194,7 @@ class TestModel(BaseDBModel): in str(exc.value) ) - def test_order_direction_ascending(self, db_mock) -> None: + def test_order_direction_ascending(self, db_mock: SqliterDB) -> None: """Test that the order method works as expected when ASC specified.""" # Define a simple model for the test @@ -222,7 +223,7 @@ class Meta: assert results[1].name == "Jim Doe" assert results[2].name == "John Doe" - def test_order_direction_desc(self, db_mock) -> None: + def test_order_direction_desc(self, db_mock: SqliterDB) -> None: """Test that the order method works as expected descending.""" # Define a simple model for the test diff --git a/tests/test_orm_fields.py b/tests/test_orm_fields.py index 5b74268f..d74f4cef 100644 --- a/tests/test_orm_fields.py +++ b/tests/test_orm_fields.py @@ -3,8 +3,7 @@ from __future__ import annotations import sys -import types -from typing import Optional +from typing import Any, Optional, cast import pytest from pydantic.fields import FieldInfo @@ -199,7 +198,7 @@ class PEP604Related(BaseDBModel): "PEP604Related | None", {"PEP604Related": PEP604Related}, ) - assert isinstance(pep604_union, types.UnionType) + assert type(pep604_union).__name__ in {"UnionType", "Union"} # Build an owner class and inject the PEP 604 annotation # Use __class_getitem__ to avoid mypy treating the variable as @@ -209,9 +208,9 @@ class OwnerPEP604(BaseDBModel): name: str - OwnerPEP604.__annotations__["rel"] = ForeignKey.__class_getitem__( - pep604_union - ) + OwnerPEP604.__annotations__["rel"] = cast( + "Any", ForeignKey + ).__class_getitem__(pep604_union) fk = ForeignKey(PEP604Related) fk._detect_nullable_from_annotation(OwnerPEP604, "rel") diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index adbd430b..a982bbeb 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -180,7 +180,7 @@ def test_access_prefetched_data(self, db: SqliterDB) -> None: jane = next(a for a in authors if a.name == "Jane Austen") result = jane.books assert isinstance(result, PrefetchedResult) - books = result.fetch_all() + books = cast("list[Book]", result.fetch_all()) assert len(books) == 2 titles = {b.title for b in books} assert "Pride and Prejudice" in titles @@ -190,33 +190,35 @@ def test_prefetched_count_and_exists(self, db: SqliterDB) -> None: authors = db.select(Author).prefetch_related("books").fetch_all() jane = next(a for a in authors if a.name == "Jane Austen") - assert jane.books.count() == 2 - assert jane.books.exists() is True + prefetched_books = cast("Any", jane.books) + assert prefetched_books.count() == 2 + assert prefetched_books.exists() is True nobooks = next(a for a in authors if a.name == "No Books Author") - assert nobooks.books.count() == 0 - assert nobooks.books.exists() is False + nobooks_books = cast("Any", nobooks.books) + assert nobooks_books.count() == 0 + assert nobooks_books.exists() is False def test_prefetched_fetch_one(self, db: SqliterDB) -> None: """Prefetched .fetch_one() returns first or None.""" authors = db.select(Author).prefetch_related("books").fetch_all() jane = next(a for a in authors if a.name == "Jane Austen") - book = jane.books.fetch_one() + book = cast("Any", jane.books).fetch_one() assert book is not None nobooks = next(a for a in authors if a.name == "No Books Author") - assert nobooks.books.fetch_one() is None + assert cast("Any", nobooks.books).fetch_one() is None def test_prefetched_filter_falls_back_to_db(self, db: SqliterDB) -> None: """Prefetched .filter() falls back to a DB query.""" authors = db.select(Author).prefetch_related("books").fetch_all() jane = next(a for a in authors if a.name == "Jane Austen") - result = jane.books.filter(year__gt=1812) + result = cast("Any", jane.books).filter(year__gt=1812) # filter() returns a ReverseQuery (falls back to DB) assert isinstance(result, ReverseQuery) - filtered = result.fetch_all() + filtered = cast("list[Book]", result.fetch_all()) assert len(filtered) == 1 assert filtered[0].title == "Pride and Prejudice" @@ -227,8 +229,8 @@ def test_multiple_prefetch_paths(self, db: SqliterDB) -> None: ) jane = next(a for a in authors if a.name == "Jane Austen") - assert jane.books.count() == 2 - assert jane.reviews.count() == 2 + assert cast("Any", jane.books).count() == 2 + assert cast("Any", jane.reviews).count() == 2 def test_no_related_objects_get_empty_list(self, db: SqliterDB) -> None: """Instances with no related objects get [] in cache.""" @@ -248,7 +250,7 @@ def test_combined_with_filter(self, db: SqliterDB) -> None: ) assert len(authors) == 1 - assert authors[0].books.count() == 2 + assert cast("Any", authors[0].books).count() == 2 def test_combined_with_order_and_limit(self, db: SqliterDB) -> None: """Chaining with order and limit works.""" @@ -322,7 +324,7 @@ def test_m2m_reverse_prefetch(self, db: SqliterDB) -> None: assert len(articles) == 2 unused_tag = next(t for t in tags if t.name == "unused") - assert unused_tag.articles.count() == 0 + assert cast("Any", unused_tag.articles).count() == 0 def test_m2m_prefetched_count_exists(self, db: SqliterDB) -> None: """Prefetched M2M .count() and .exists() work.""" @@ -685,7 +687,7 @@ def test_two_level_reverse_fk_then_m2m(self, nested_db: SqliterDB) -> None: ) jane = next(a for a in authors if a.name == "Jane Austen") - books = jane.authored_books.fetch_all() + books = cast("Any", jane.authored_books).fetch_all() assert len(books) == 2 pride = next(b for b in books if b.title == "Pride and Prejudice") @@ -705,7 +707,7 @@ def test_two_level_reverse_m2m_then_reverse_fk( ) python_tag = next(t for t in tags if t.name == "python") - articles = python_tag.articles.fetch_all() + articles = cast("Any", python_tag.articles).fetch_all() assert len(articles) == 2 guide = next(a for a in articles if a.title == "SQLiter Guide") @@ -725,11 +727,11 @@ def test_mixed_nested_and_flat_paths(self, nested_db: SqliterDB) -> None: jane = next(a for a in authors if a.name == "Jane Austen") # Flat path - reviews = jane.reviews.fetch_all() + reviews = cast("Any", jane.reviews).fetch_all() assert len(reviews) == 1 # Nested path - books = jane.authored_books.fetch_all() + books = cast("Any", jane.authored_books).fetch_all() pride = next(b for b in books if b.title == "Pride and Prejudice") cats = pride.categories.fetch_all() assert len(cats) == 2 @@ -743,7 +745,7 @@ def test_overlapping_paths_deduplicated(self, nested_db: SqliterDB) -> None: ) jane = next(a for a in authors if a.name == "Jane Austen") - books = jane.authored_books.fetch_all() + books = cast("Any", jane.authored_books).fetch_all() assert len(books) == 2 pride = next(b for b in books if b.title == "Pride and Prejudice") @@ -760,7 +762,7 @@ def test_nested_prefetch_with_filter(self, nested_db: SqliterDB) -> None: ) assert len(authors) == 1 - books = authors[0].authored_books.fetch_all() + books = cast("Any", authors[0].authored_books).fetch_all() assert len(books) == 2 def test_nested_prefetch_with_fetch_one(self, nested_db: SqliterDB) -> None: @@ -773,7 +775,7 @@ def test_nested_prefetch_with_fetch_one(self, nested_db: SqliterDB) -> None: ) assert author is not None - books = author.authored_books.fetch_all() + books = cast("Any", author.authored_books).fetch_all() assert len(books) == 2 pride = next(b for b in books if b.title == "Pride and Prejudice") @@ -789,7 +791,7 @@ def test_empty_intermediate_level(self, nested_db: SqliterDB) -> None: ) nobooks = next(a for a in authors if a.name == "No Books Author") - books = nobooks.authored_books.fetch_all() + books = cast("Any", nobooks.authored_books).fetch_all() assert books == [] def test_cache_populated_at_each_level(self, nested_db: SqliterDB) -> None: @@ -849,11 +851,16 @@ def test_unresolved_reverse_m2m_forward_ref(self) -> None: class LocalHost(BaseDBModel): name: str - LocalHost.ghosts = ReverseManyToMany( - from_model=cast("type[Any]", "GhostModel"), - to_model=LocalHost, - junction_table="ghost_host", - related_name="ghosts", + ghosts_attr = "ghosts" + setattr( + LocalHost, + ghosts_attr, + ReverseManyToMany( + from_model=cast("type[Any]", "GhostModel"), + to_model=LocalHost, + junction_table="ghost_host", + related_name="ghosts", + ), ) db = SqliterDB(":memory:") @@ -873,7 +880,7 @@ def test_three_levels_deep(self, nested_db: SqliterDB) -> None: ) python_tag = next(t for t in tags if t.name == "python") - articles = python_tag.articles.fetch_all() + articles = cast("Any", python_tag.articles).fetch_all() guide = next(a for a in articles if a.title == "SQLiter Guide") comments = guide.comments.fetch_all() assert len(comments) == 2 @@ -921,7 +928,7 @@ def test_all_parents_lack_books_nested(self, nested_db: SqliterDB) -> None: assert len(authors) == 1 nobooks = authors[0] - assert nobooks.authored_books.fetch_all() == [] + assert cast("Any", nobooks.authored_books).fetch_all() == [] def test_prefetch_segment_skips_pkless_instances( self, nested_db: SqliterDB diff --git a/tests/test_properties.py b/tests/test_properties.py index 35ce662c..4b0dafb6 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -3,6 +3,7 @@ import tempfile import pytest +from pytest_mock import MockerFixture from sqliter.exceptions import DatabaseConnectionError from sqliter.model.model import BaseDBModel @@ -131,7 +132,9 @@ class Meta: f"Expected 'test_table', got {table_names}" ) - def test_table_names_connection_failure(self, mocker) -> None: + def test_table_names_connection_failure( + self, mocker: MockerFixture + ) -> None: """Test 'table_names' raises exception if the connection fails.""" # Create an instance of the database db = SqliterDB(memory=True) diff --git a/tests/test_query.py b/tests/test_query.py index d4002efe..611d0baf 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -3,6 +3,7 @@ from typing import Optional import pytest +from pytest_mock import MockerFixture from sqliter.exceptions import ( InvalidFilterError, @@ -18,7 +19,7 @@ class TestQuery: """Test cases for the QueryBuilder class.""" - def test_fetch_all_no_results(self, db_mock) -> None: + def test_fetch_all_no_results(self, db_mock: SqliterDB) -> None: """Test that fetch_all returns None when no results are found.""" # Define a simple model for the test @@ -37,7 +38,7 @@ class Meta: # Assert that fetch_all returns None when no results are found assert result == [] - def test_fetch_one_single_result(self, db_mock) -> None: + def test_fetch_one_single_result(self, db_mock: SqliterDB) -> None: """Test that fetch_one returns a single result as a model instance.""" # Define a simple model for the test @@ -60,7 +61,7 @@ class Meta: assert result is not None assert result.name == "John Doe" - def test_fetch_one_no_results(self, db_mock) -> None: + def test_fetch_one_no_results(self, db_mock: SqliterDB) -> None: """Test that fetch_one returns None when no results are found.""" # Define a simple model for the test @@ -79,7 +80,7 @@ class Meta: # Assert that fetch_one returns None when no results are found assert result is None - def test_fetch_first_single_result(self, db_mock) -> None: + def test_fetch_first_single_result(self, db_mock: SqliterDB) -> None: """Test that fetch_first returns a single result as a model instance.""" # Define a simple model for the test @@ -102,7 +103,7 @@ class Meta: assert result is not None assert result.name == "John Doe" - def test_fetch_first_no_results(self, db_mock) -> None: + def test_fetch_first_no_results(self, db_mock: SqliterDB) -> None: """Test that fetch_first returns None when no results are found.""" # Define a simple model for the test @@ -121,7 +122,7 @@ class Meta: # Assert that fetch_first returns None when no results are found assert result is None - def test_fetch_first_multiple_results(self, db_mock) -> None: + def test_fetch_first_multiple_results(self, db_mock: SqliterDB) -> None: """Test fetch_first returns the first result as a model instance.""" # Define a simple model for the test @@ -145,7 +146,7 @@ class Meta: assert result is not None assert result.name == "John Doe" - def test_fetch_last_single_result(self, db_mock) -> None: + def test_fetch_last_single_result(self, db_mock: SqliterDB) -> None: """Test that fetch_last returns a single result as a model instance.""" # Define a simple model for the test @@ -168,7 +169,7 @@ class Meta: assert result is not None assert result.name == "John Doe" - def test_fetch_last_no_results(self, db_mock) -> None: + def test_fetch_last_no_results(self, db_mock: SqliterDB) -> None: """Test that fetch_last returns None when no results are found.""" # Define a simple model for the test @@ -187,7 +188,7 @@ class Meta: # Assert that fetch_last returns None when no results are found assert result is None - def fetch_last_multiple_results(self, db_mock) -> None: + def test_fetch_last_multiple_results(self, db_mock: SqliterDB) -> None: """Test that fetch_last returns the last result as a model instance.""" # Define a simple model for the test @@ -211,7 +212,7 @@ class Meta: assert result is not None assert result.name == "Jane Doe" - def test_filter_single_condition(self, db_mock) -> None: + def test_filter_single_condition(self, db_mock: SqliterDB) -> None: """Test filtering with a single condition.""" # Define a model for the test @@ -238,7 +239,7 @@ class Meta: assert results[0].name == "John Doe" assert results[0].age == 30 - def test_filter_multiple_conditions(self, db_mock) -> None: + def test_filter_multiple_conditions(self, db_mock: SqliterDB) -> None: """Test filtering with multiple conditions.""" # Define a model for the test @@ -267,7 +268,7 @@ class Meta: assert results[0].name == "John Doe" assert results[0].age == 30 - def test_filter_no_matching_results(self, db_mock) -> None: + def test_filter_no_matching_results(self, db_mock: SqliterDB) -> None: """Test filtering that returns no matching results.""" # Define a model for the test @@ -293,7 +294,7 @@ class Meta: # Assert that no results are returned assert len(results) == 0 - def test_filter_numeric_condition(self, db_mock) -> None: + def test_filter_numeric_condition(self, db_mock: SqliterDB) -> None: """Test filtering using a numeric condition.""" # Define a model for the test @@ -318,7 +319,7 @@ class Meta: assert results[0].name == "John Smith" assert results[0].age == 40 - def test_filter_multiple_results(self, db_mock) -> None: + def test_filter_multiple_results(self, db_mock: SqliterDB) -> None: """Test filtering that returns multiple matching results.""" # Define a model for the test @@ -344,7 +345,7 @@ class Meta: assert len(results) == 1 # Only one 'John Doe' assert results[0].name == "John Doe" - def test_filter_with_none_condition(self, db_mock) -> None: + def test_filter_with_none_condition(self, db_mock: SqliterDB) -> None: """Test filtering with None as a condition.""" # Define a model for the test @@ -368,7 +369,7 @@ class Meta: assert results[0].name == "Jane Doe" assert results[0].age is None - def test_limit(self, db_mock) -> None: + def test_limit(self, db_mock: SqliterDB) -> None: """Test that the limit method works as expected.""" # Define a simple model for the test @@ -392,7 +393,7 @@ class Meta: assert results[0].name == "John Doe" assert results[1].name == "Jane Doe" - def test_offset(self, db_mock) -> None: + def test_offset(self, db_mock: SqliterDB) -> None: """Test that the offset method works as expected.""" # Define a simple model for the test @@ -416,7 +417,7 @@ class Meta: assert results[0].name == "Jane Doe" assert results[1].name == "Jim Doe" - def test_limit_offset_order_combined(self, db_mock) -> None: + def test_limit_offset_order_combined(self, db_mock: SqliterDB) -> None: """Test that limit, offset, and order can work together.""" # Define a simple model for the test @@ -447,7 +448,7 @@ class Meta: assert results[0].name == "Jane Doe" assert results[1].name == "Jim Doe" - def test_limit_edge_cases(self, db_mock) -> None: + def test_limit_edge_cases(self, db_mock: SqliterDB) -> None: """Test limit with edge cases like zero and negative values.""" # Define a simple model for the test @@ -470,7 +471,7 @@ class Meta: results = db_mock.select(EdgeCaseTestModel).limit(-1).fetch_all() assert len(results) == 2 - def test_offset_exceeding_row_count(self, db_mock) -> None: + def test_offset_exceeding_row_count(self, db_mock: SqliterDB) -> None: """Test that an offset > the number of rows returns an empty result.""" # Insert multiple records for i in range(3): @@ -488,7 +489,7 @@ def test_offset_exceeding_row_count(self, db_mock) -> None: # Assert that the result is an empty list assert result == [] - def test_offset_edge_cases(self, db_mock) -> None: + def test_offset_edge_cases(self, db_mock: SqliterDB) -> None: """Test offset with edge cases like zero and negative values.""" # Define a simple model for the test @@ -517,7 +518,7 @@ class Meta: assert len(results) == 1 assert results[0].name == "Jane Doe" - def test_query_non_existent_table(self, db_mock) -> None: + def test_query_non_existent_table(self, db_mock: SqliterDB) -> None: """Test querying a non-existent table raises RecordFetchError.""" class NonExistentModel(ExampleModel): @@ -527,7 +528,7 @@ class Meta: with pytest.raises(RecordFetchError): db_mock.select(NonExistentModel).fetch_all() - def test_query_invalid_filter(self, db_mock) -> None: + def test_query_invalid_filter(self, db_mock: SqliterDB) -> None: """Test applying an invalid filter raises RecordFetchError.""" # Ensure the table is created db_mock.create_table(ExampleModel) @@ -538,7 +539,7 @@ def test_query_invalid_filter(self, db_mock) -> None: non_existent_field="value" ).fetch_all() - def test_query_valid_filter(self, db_mock) -> None: + def test_query_valid_filter(self, db_mock: SqliterDB) -> None: """Test that valid filter fields do not raise InvalidFilterError.""" # Ensure the table is created db_mock.create_table(ExampleModel) @@ -549,7 +550,7 @@ def test_query_valid_filter(self, db_mock) -> None: except InvalidFilterError: pytest.fail("Valid field raised InvalidFilterError unexpectedly") - def test_query_mixed_valid_invalid_filter(self, db_mock) -> None: + def test_query_mixed_valid_invalid_filter(self, db_mock: SqliterDB) -> None: """Test a mix of valid and invalid fields raises InvalidFilterError.""" # Ensure the table is created db_mock.create_table(ExampleModel) @@ -560,7 +561,9 @@ def test_query_mixed_valid_invalid_filter(self, db_mock) -> None: name="Valid Name", non_existent_field="Invalid" ).fetch_all() - def test_filter_rejects_list_for_equality_operators(self, db_mock) -> None: + def test_filter_rejects_list_for_equality_operators( + self, db_mock: SqliterDB + ) -> None: """Test that equality operators reject list values.""" db_mock.create_table(ExampleModel) @@ -577,7 +580,7 @@ def test_filter_rejects_list_for_equality_operators(self, db_mock) -> None: ).fetch_all() def test_filter_rejects_list_for_comparison_operators( - self, db_mock + self, db_mock: SqliterDB ) -> None: """Test that comparison operators reject list values.""" db_mock.create_table(ExampleModel) @@ -602,7 +605,9 @@ def test_filter_rejects_list_for_comparison_operators( name__gte=["a", "b"] ).fetch_all() - def test_filter_accepts_list_for_in_operators(self, db_mock) -> None: + def test_filter_accepts_list_for_in_operators( + self, db_mock: SqliterDB + ) -> None: """Test that __in and __not_in accept and require list values.""" db_mock.create_table(ExampleModel) @@ -645,7 +650,9 @@ def test_filter_accepts_list_for_in_operators(self, db_mock) -> None: name__not_in="Alice" ).fetch_all() - def test_qualify_base_filter_clause_for_model_field(self, db_mock) -> None: + def test_qualify_base_filter_clause_for_model_field( + self, db_mock: SqliterDB + ) -> None: """Test base model fields are qualified in JOIN filter clauses.""" query = db_mock.select(ExampleModel) @@ -653,7 +660,9 @@ def test_qualify_base_filter_clause_for_model_field(self, db_mock) -> None: assert qualified == 't0."pk" IN (?, ?)' - def test_qualify_base_filter_clause_no_regex_match(self, db_mock) -> None: + def test_qualify_base_filter_clause_no_regex_match( + self, db_mock: SqliterDB + ) -> None: """Test clauses without a leading identifier are unchanged.""" query = db_mock.select(ExampleModel) clause = '"pk" IN (?, ?)' @@ -662,7 +671,9 @@ def test_qualify_base_filter_clause_no_regex_match(self, db_mock) -> None: assert qualified == clause - def test_qualify_base_filter_clause_non_model_field(self, db_mock) -> None: + def test_qualify_base_filter_clause_non_model_field( + self, db_mock: SqliterDB + ) -> None: """Test unknown fields are not qualified in filter clauses.""" query = db_mock.select(ExampleModel) clause = "unknown_field = ?" @@ -671,7 +682,9 @@ def test_qualify_base_filter_clause_non_model_field(self, db_mock) -> None: assert qualified == clause - def test_qualify_base_filter_clause_already_aliased(self, db_mock) -> None: + def test_qualify_base_filter_clause_already_aliased( + self, db_mock: SqliterDB + ) -> None: """Test already-aliased clauses are left unchanged.""" query = db_mock.select(ExampleModel) clause = 't1."name" LIKE ?' @@ -680,13 +693,15 @@ def test_qualify_base_filter_clause_already_aliased(self, db_mock) -> None: assert qualified == clause - def test_qualify_base_field_name_wrapper(self, db_mock) -> None: + def test_qualify_base_field_name_wrapper(self, db_mock: SqliterDB) -> None: """Test wrapper method forwards to qualified base-field rendering.""" query = db_mock.select(ExampleModel) assert query._qualify_base_field_name("pk") == 't0."pk"' assert query._qualify_base_field_name("unknown") == "unknown" - def test_fetch_result_with_list_of_tuples(self, mocker) -> None: + def test_fetch_result_with_list_of_tuples( + self, mocker: MockerFixture + ) -> None: """Test _fetch_result when _execute_query returns list of tuples.""" # ensure we get a dependable timestamp mocker.patch("time.time", return_value=1234567890) @@ -730,7 +745,7 @@ def test_exclude_pk_raises_valueerror(self) -> None: with pytest.raises(ValueError, match=match_str): db.select(ExampleModel).exclude(["pk"]) - def test_delete_all_records(self, db_mock) -> None: + def test_delete_all_records(self, db_mock: SqliterDB) -> None: """Test delete() removes all records when no filters are applied.""" # Define a simple model for the test @@ -753,7 +768,7 @@ class Meta: assert deleted_count == 3 assert db_mock.select(DeleteTestModel).count() == 0 - def test_delete_filtered_records(self, db_mock) -> None: + def test_delete_filtered_records(self, db_mock: SqliterDB) -> None: """Test that delete() removes only records matching the filter.""" # Define a simple model for the test @@ -785,7 +800,7 @@ class Meta: assert remaining[0].name == "John" assert remaining[1].name == "Jane" - def test_delete_no_matches(self, db_mock) -> None: + def test_delete_no_matches(self, db_mock: SqliterDB) -> None: """Test that delete() returns 0 when no records match the filter.""" # Define a simple model for the test @@ -808,7 +823,7 @@ class Meta: assert deleted_count == 0 assert db_mock.select(DeleteTestModel).count() == 1 - def test_delete_with_complex_filters(self, db_mock) -> None: + def test_delete_with_complex_filters(self, db_mock: SqliterDB) -> None: """Test deleting records with multiple filter conditions.""" # Define a model for the test @@ -852,7 +867,7 @@ class Meta: assert "Alice" not in remaining_names assert set(remaining_names) == {"John", "Jane"} - def test_delete_with_null_values(self, db_mock) -> None: + def test_delete_with_null_values(self, db_mock: SqliterDB) -> None: """Test deleting records with NULL value conditions.""" # Define a model for the test @@ -885,7 +900,7 @@ class Meta: assert all(record.optional_field is not None for record in remaining) assert {record.name for record in remaining} == {"John", "Bob"} - def test_delete_database_error(self, db_mock) -> None: + def test_delete_database_error(self, db_mock: SqliterDB) -> None: """Test that database errors during delete are handled properly.""" # Define a model for the test @@ -909,7 +924,7 @@ class Meta: # Verify the error message assert "error_delete_table" in str(exc.value) - def test_delete_ignores_limit_offset(self, db_mock) -> None: + def test_delete_ignores_limit_offset(self, db_mock: SqliterDB) -> None: """Test that delete operation ignores LIMIT and OFFSET clauses.""" # Define a model for the test @@ -933,7 +948,9 @@ class Meta: assert deleted_count == 5 assert db_mock.select(LimitOffsetModel).count() == 0 - def test_delete_with_auto_commit(self, db_mock, mocker) -> None: + def test_delete_with_auto_commit( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test delete behavior with auto_commit enabled and disabled.""" # Create a model for the test @@ -975,7 +992,7 @@ class Meta: # Records should be gone assert db_mock.select(AutoCommitModel).count() == 0 - def test_delete_ignores_order(self, db_mock) -> None: + def test_delete_ignores_order(self, db_mock: SqliterDB) -> None: """Test that delete operation ignores ORDER BY clause.""" # Define a model for the test @@ -1007,7 +1024,7 @@ class Meta: assert remaining[0].name == "B" assert remaining[0].value == 1 - def test_delete_empty_table(self, db_mock) -> None: + def test_delete_empty_table(self, db_mock: SqliterDB) -> None: """Test deleting from an empty table.""" # Define a model for the test @@ -1027,7 +1044,9 @@ class Meta: assert deleted_count == 0 assert db_mock.select(EmptyModel).count() == 0 - def test_delete_with_debug_logging(self, db_mock, mocker) -> None: + def test_delete_with_debug_logging( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that delete operation logs SQL when debug is enabled.""" # Define a model for the test diff --git a/tests/test_sqliter.py b/tests/test_sqliter.py index 6429f33b..ca63739a 100644 --- a/tests/test_sqliter.py +++ b/tests/test_sqliter.py @@ -1,6 +1,7 @@ """Test suite for the 'sqliter' library.""" import pytest +from pytest_mock import MockerFixture from sqliter import SqliterDB from sqliter.exceptions import ( @@ -67,7 +68,7 @@ def test_data_lost_when_auto_commit_disabled(self) -> None: with pytest.raises(RecordFetchError): db.get(ExampleModel, result.pk) - def test_create_table(self, db_mock) -> None: + def test_create_table(self, db_mock: SqliterDB) -> None: """Test table creation.""" with db_mock.connect() as conn: cursor = conn.cursor() @@ -76,12 +77,12 @@ def test_create_table(self, db_mock) -> None: assert len(tables) == 2 assert ("test_table",) in tables - def test_close_connection(self, db_mock) -> None: + def test_close_connection(self, db_mock: SqliterDB) -> None: """Test closing the connection.""" db_mock.close() assert db_mock.conn is None - def test_commit_changes(self, mocker) -> None: + def test_commit_changes(self, mocker: MockerFixture) -> None: """Test committing changes to the database.""" db = SqliterDB(":memory:", auto_commit=False) db.create_table(ExampleModel) @@ -95,7 +96,9 @@ def test_commit_changes(self, mocker) -> None: assert mock_conn.commit.called - def test_create_table_with_default_auto_increment(self, db_mock) -> None: + def test_create_table_with_default_auto_increment( + self, db_mock: SqliterDB + ) -> None: """Test table creation with auto-incrementing primary key.""" class AutoIncrementModel(BaseDBModel): @@ -118,7 +121,7 @@ class Meta: assert table_info[0][2] == "INTEGER" # Column type assert table_info[0][5] == 1 # Primary key flag - def test_default_table_name(self, db_mock) -> None: + def test_default_table_name(self, db_mock: SqliterDB) -> None: """Test the default table name generation. It should default to the class name in lowercase, plural form. @@ -133,7 +136,9 @@ class Meta: # Verify that get_table_name defaults to class name in lowercase assert DefaultNameModel.get_table_name() == "default_names" - def test_get_table_name_fallback_without_inflect(self, mocker) -> None: + def test_get_table_name_fallback_without_inflect( + self, mocker: MockerFixture + ) -> None: """Test get_table_name falls back to manual plural without 'inflect.""" # Mock the inflect import to raise ImportError for `inflect` mocker.patch.dict("sys.modules", {"inflect": None}) @@ -144,7 +149,9 @@ class UserModel(BaseDBModel): table_name = UserModel.get_table_name() assert table_name == "users" # Fallback logic should add 's' - def test_get_table_name_no_double_s_without_inflect(self, mocker) -> None: + def test_get_table_name_no_double_s_without_inflect( + self, mocker: MockerFixture + ) -> None: """Test get_table_name doesn't add extra 's' if already there.""" # Mock the sys.modules to simulate 'inflect' being unavailable mocker.patch.dict("sys.modules", {"inflect": None}) @@ -170,7 +177,7 @@ class PersonModel(BaseDBModel): ", or ignore this failure." ) - def test_insert_license(self, db_mock) -> None: + def test_insert_license(self, db_mock: SqliterDB) -> None: """Test inserting a license into the database.""" test_model = ExampleModel( slug="mit", name="MIT License", content="MIT License Content" @@ -186,7 +193,7 @@ def test_insert_license(self, db_mock) -> None: assert result[4] == "MIT License" assert result[5] == "MIT License Content" - def test_fetch_license(self, db_mock) -> None: + def test_fetch_license(self, db_mock: SqliterDB) -> None: """Test fetching a license by primary key.""" test_model = ExampleModel( slug="gpl", name="GPL License", content="GPL License Content" @@ -199,7 +206,7 @@ def test_fetch_license(self, db_mock) -> None: assert fetched_license.name == "GPL License" assert fetched_license.content == "GPL License Content" - def test_update(self, db_mock) -> None: + def test_update(self, db_mock: SqliterDB) -> None: """Test updating an existing license.""" test_model = ExampleModel( slug="mit", name="MIT License", content="MIT License Content" @@ -212,9 +219,10 @@ def test_update(self, db_mock) -> None: # Fetch and check if updated fetched_license = db_mock.get(ExampleModel, result.pk) + assert fetched_license is not None assert fetched_license.content == "Updated MIT License Content" - def test_delete(self, db_mock) -> None: + def test_delete(self, db_mock: SqliterDB) -> None: """Test deleting a license.""" test_model = ExampleModel( slug="mit", name="MIT License", content="MIT License Content" @@ -228,7 +236,7 @@ def test_delete(self, db_mock) -> None: fetched_license = db_mock.get(ExampleModel, result.pk) assert fetched_license is None - def test_select_filter(self, db_mock) -> None: + def test_select_filter(self, db_mock: SqliterDB) -> None: """Test filtering licenses using the QueryBuilder.""" license1 = ExampleModel( slug="mit", name="MIT License", content="MIT License Content" @@ -246,7 +254,7 @@ def test_select_filter(self, db_mock) -> None: assert len(filtered) == 1 assert filtered[0].slug == "gpl" - def test_query_fetch_first(self, db_mock) -> None: + def test_query_fetch_first(self, db_mock: SqliterDB) -> None: """Test fetching the first record.""" license1 = ExampleModel( slug="mit", name="MIT License", content="MIT License Content" @@ -258,9 +266,10 @@ def test_query_fetch_first(self, db_mock) -> None: db_mock.insert(license2) first_record = db_mock.select(ExampleModel).fetch_first() + assert first_record is not None assert first_record.slug == "mit" - def test_query_fetch_last(self, db_mock) -> None: + def test_query_fetch_last(self, db_mock: SqliterDB) -> None: """Test fetching the last record.""" license1 = ExampleModel( slug="mit", name="MIT License", content="MIT License Content" @@ -272,9 +281,10 @@ def test_query_fetch_last(self, db_mock) -> None: db_mock.insert(license2) last_record = db_mock.select(ExampleModel).fetch_last() + assert last_record is not None assert last_record.slug == "gpl" - def test_count_records(self, db_mock) -> None: + def test_count_records(self, db_mock: SqliterDB) -> None: """Test counting records in the database.""" license1 = ExampleModel( slug="mit", name="MIT License", content="MIT License Content" @@ -288,7 +298,7 @@ def test_count_records(self, db_mock) -> None: count = db_mock.select(ExampleModel).count() assert count == 2 - def test_exists_record(self, db_mock) -> None: + def test_exists_record(self, db_mock: SqliterDB) -> None: """Test checking if a record exists.""" license1 = ExampleModel( slug="mit", name="MIT License", content="MIT License Content" @@ -298,7 +308,9 @@ def test_exists_record(self, db_mock) -> None: exists = db_mock.select(ExampleModel).filter(slug="mit").exists() assert exists - def test_transaction_commit(self, db_mock, mocker) -> None: + def test_transaction_commit( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test if auto_commit works correctly when enabled.""" # Mock the commit method on the connection mock_conn = mocker.MagicMock() @@ -320,7 +332,7 @@ def test_transaction_commit(self, db_mock, mocker) -> None: # Ensure commit was called only once, when the context manager exited. assert mock_conn.commit.call_count == 1 - def test_transaction_manual_commit(self, mocker) -> None: + def test_transaction_manual_commit(self, mocker: MockerFixture) -> None: """Test context-manager commit when auto_commit is set to False. Regardless of the auto_commit setting, the context manager should commit @@ -348,7 +360,7 @@ def test_transaction_manual_commit(self, mocker) -> None: # After leaving the context, commit should now be called mock_conn.commit.assert_called_once() - def test_update_existing_record(self, db_mock) -> None: + def test_update_existing_record(self, db_mock: SqliterDB) -> None: """Test that updating an existing record works correctly.""" # Insert an example record example_model = ExampleModel( @@ -365,7 +377,7 @@ def test_update_existing_record(self, db_mock) -> None: assert updated_record is not None assert updated_record.content == "Updated Content" - def test_update_non_existing_record(self, db_mock) -> None: + def test_update_non_existing_record(self, db_mock: SqliterDB) -> None: """Test updating a non-existing record raises RecordNotFoundError.""" # Create an example record that is not inserted into the DB example_model = ExampleModel( @@ -383,7 +395,7 @@ def test_update_non_existing_record(self, db_mock) -> None: exc_info.value ) - def test_get_non_existent_table(self, db_mock) -> None: + def test_get_non_existent_table(self, db_mock: SqliterDB) -> None: """Test fetching from a non-existent table raises RecordFetchError.""" class NonExistentModel(ExampleModel): @@ -391,19 +403,19 @@ class Meta: table_name = "non_existent_table" # A table that doesn't exist with pytest.raises(RecordFetchError): - db_mock.get(NonExistentModel, "non_existent_key") + db_mock.get(NonExistentModel, -1) - def test_get_record_no_result(self, db_mock) -> None: + def test_get_record_no_result(self, db_mock: SqliterDB) -> None: """Test fetching a non-existent record returns None.""" - result = db_mock.get(ExampleModel, "non_existent_key") + result = db_mock.get(ExampleModel, -1) assert result is None - def test_delete_non_existent_record(self, db_mock) -> None: + def test_delete_non_existent_record(self, db_mock: SqliterDB) -> None: """Test that trying to delete a non-existent record raises exception.""" with pytest.raises(RecordNotFoundError): - db_mock.delete(ExampleModel, "non_existent_key") + db_mock.delete(ExampleModel, -1) - def test_delete_existing_record(self, db_mock) -> None: + def test_delete_existing_record(self, db_mock: SqliterDB) -> None: """Test that a record is deleted successfully.""" # Insert a record first test_model = ExampleModel( @@ -415,8 +427,8 @@ def test_delete_existing_record(self, db_mock) -> None: db_mock.delete(ExampleModel, result.pk) # Fetch the deleted record to confirm it's gone - result = db_mock.get(ExampleModel, result.pk) - assert result is None + deleted_result = db_mock.get(ExampleModel, result.pk) + assert deleted_result is None def test_select_with_exclude_single_field( self, @@ -500,7 +512,9 @@ def test_error_when_no_db_name_and_not_memory(self) -> None: with pytest.raises(ValueError, match="Database name must be provided"): SqliterDB(memory=False) - def test_file_is_created_when_filename_is_provided(self, mocker) -> None: + def test_file_is_created_when_filename_is_provided( + self, mocker: MockerFixture + ) -> None: """Test that sqlite3.connect is called with the correct file path.""" mock_connect = mocker.patch("sqlite3.connect") @@ -511,7 +525,9 @@ def test_file_is_created_when_filename_is_provided(self, mocker) -> None: # Check if sqlite3.connect was called with the correct filename mock_connect.assert_called_with(db_filename) - def test_memory_database_no_file_created(self, mocker) -> None: + def test_memory_database_no_file_created( + self, mocker: MockerFixture + ) -> None: """Test sqlite3.connect is called with ':memory:' when memory=True.""" mock_connect = mocker.patch("sqlite3.connect") @@ -522,7 +538,7 @@ def test_memory_database_no_file_created(self, mocker) -> None: # DB mock_connect.assert_called_with(":memory:") - def test_memory_db_ignores_filename(self, mocker) -> None: + def test_memory_db_ignores_filename(self, mocker: MockerFixture) -> None: """Test memory=True igores any filename, creating an in-memory DB.""" mock_connect = mocker.patch("sqlite3.connect") @@ -534,7 +550,7 @@ def test_memory_db_ignores_filename(self, mocker) -> None: # filename mock_connect.assert_called_with(":memory:") - def test_complex_model_field_types(self, db_mock) -> None: + def test_complex_model_field_types(self, db_mock: SqliterDB) -> None: """Test that the table is created with the correct field types.""" # Create table based on ComplexModel db_mock.create_table(ComplexModel) @@ -569,7 +585,7 @@ def test_complex_model_field_types(self, db_mock) -> None: f"but got {column_type}" ) - def test_complex_model_primary_key(self, db_mock) -> None: + def test_complex_model_primary_key(self, db_mock: SqliterDB) -> None: """Test that the primary key is correctly created for ComplexModel.""" # Create table based on ComplexModel db_mock.create_table(ComplexModel) @@ -598,7 +614,7 @@ def test_complex_model_primary_key(self, db_mock) -> None: f"{primary_key_column[2]}" ) - def test_reset_database_on_init(self, temp_db_path) -> None: + def test_reset_database_on_init(self, temp_db_path: str) -> None: """Test that the database is reset when reset=True is passed.""" class TestModel(BaseDBModel): @@ -620,7 +636,9 @@ class Meta: with pytest.raises(RecordFetchError): db_reset.select(TestModel).fetch_all() - def test_reset_database_preserves_connection(self, temp_db_path) -> None: + def test_reset_database_preserves_connection( + self, temp_db_path: str + ) -> None: """Test that resetting the database doesn't break the connection.""" class TestModel(BaseDBModel): @@ -639,7 +657,9 @@ class Meta: result = db.select(TestModel).fetch_all() assert len(result) == 1 - def test_reset_database_with_multiple_tables(self, temp_db_path) -> None: + def test_reset_database_with_multiple_tables( + self, temp_db_path: str + ) -> None: """Test that reset drops all tables in the database.""" class TestModel1(BaseDBModel): @@ -671,7 +691,7 @@ class Meta: with pytest.raises(RecordFetchError): db_reset.select(TestModel2).fetch_all() - def test_create_table_exists_ok_true(self, db_mock) -> None: + def test_create_table_exists_ok_true(self, db_mock: SqliterDB) -> None: """Test creating a table with exists_ok=True (default behavior).""" # First creation should succeed db_mock.create_table(ExistOkModel) @@ -682,7 +702,7 @@ def test_create_table_exists_ok_true(self, db_mock) -> None: except TableCreationError as e: pytest.fail(f"create_table raised {type(e).__name__} unexpectedly!") - def test_create_table_exists_ok_false(self, db_mock) -> None: + def test_create_table_exists_ok_false(self, db_mock: SqliterDB) -> None: """Test creating a table with exists_ok=False.""" # First creation should succeed db_mock.create_table(ExistOkModel) @@ -713,7 +733,9 @@ class Meta: # Clean up new_db.close() - def test_create_table_sql_generation(self, db_mock, mocker) -> None: + def test_create_table_sql_generation( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test SQL generation for table creation based on exists_ok value.""" mock_cursor = mocker.MagicMock() mocker.patch.object( @@ -794,14 +816,16 @@ class Meta: # Clean up db.close() - def test_create_table_force_and_exists_ok(self, db_mock) -> None: + def test_create_table_force_and_exists_ok(self, db_mock: SqliterDB) -> None: """Test interaction between force and exists_ok parameters.""" # force=True should take precedence over exists_ok=False db_mock.create_table(ExistOkModel) db_mock.create_table(ExistOkModel, exists_ok=False, force=True) # This should not raise an error - def test_create_table_sql_generation_force(self, db_mock, mocker) -> None: + def test_create_table_sql_generation_force( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test SQL generation for table creation with force=True.""" mock_cursor = mocker.MagicMock() mocker.patch.object( diff --git a/tests/test_timestamps.py b/tests/test_timestamps.py index 3316d309..cdc75358 100644 --- a/tests/test_timestamps.py +++ b/tests/test_timestamps.py @@ -2,13 +2,18 @@ from datetime import datetime, timezone +from pytest_mock import MockerFixture + +from sqliter.sqliter import SqliterDB from tests.conftest import ExampleModel class TestTimestamps: """Test the `created_at` and `updated_at` timestamps.""" - def test_insert_timestamps(self, db_mock, mocker) -> None: + def test_insert_timestamps( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test both timestamps are set on record insert.""" # Mock time.time() to return a fixed timestamp mocker.patch("time.time", return_value=1234567890) @@ -25,7 +30,9 @@ def test_insert_timestamps(self, db_mock, mocker) -> None: assert returned_instance.created_at == 1234567890 assert returned_instance.updated_at == 1234567890 - def test_update_timestamps(self, db_mock, mocker) -> None: + def test_update_timestamps( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that the `updated_at` timestamp is updated on record update.""" # Mock time.time() to return a fixed timestamp for the update mocker.patch("time.time", return_value=1234567890) @@ -51,7 +58,9 @@ def test_update_timestamps(self, db_mock, mocker) -> None: returned_instance.updated_at == 1234567891 ) # Should be updated to the new timestamp - def test_insert_with_provided_timestamps(self, db_mock, mocker) -> None: + def test_insert_with_provided_timestamps( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that user-provided timestamps are respected on insert.""" # Mock time.time() to return a fixed timestamp mocker.patch("time.time", return_value=1234567890) @@ -74,7 +83,9 @@ def test_insert_with_provided_timestamps(self, db_mock, mocker) -> None: assert returned_instance.created_at == 1111111111 assert returned_instance.updated_at == 1111111111 - def test_insert_with_default_timestamps(self, db_mock, mocker) -> None: + def test_insert_with_default_timestamps( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that timestamps are set when created_at and updated_at are 0.""" # Mock time.time() to return a fixed timestamp mocker.patch("time.time", return_value=1234567890) @@ -95,7 +106,9 @@ def test_insert_with_default_timestamps(self, db_mock, mocker) -> None: assert returned_instance.created_at == 1234567890 assert returned_instance.updated_at == 1234567890 - def test_insert_with_mixed_timestamps(self, db_mock, mocker) -> None: + def test_insert_with_mixed_timestamps( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test a mix of user-provided and default timestamps work on insert.""" # Mock time.time() to return a fixed timestamp mocker.patch("time.time", return_value=1234567890) @@ -119,7 +132,9 @@ def test_insert_with_mixed_timestamps(self, db_mock, mocker) -> None: assert returned_instance.created_at == 1111111111 assert returned_instance.updated_at == 1234567890 - def test_update_timestamps_on_change(self, db_mock, mocker) -> None: + def test_update_timestamps_on_change( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that only `updated_at` changes on update.""" # Mock time.time() to return a fixed timestamp for the insert mocker.patch("time.time", return_value=1234567890) @@ -141,7 +156,9 @@ def test_update_timestamps_on_change(self, db_mock, mocker) -> None: assert returned_instance.created_at == 1234567890 assert returned_instance.updated_at == 1234567891 - def test_no_change_if_timestamps_already_set(self, db_mock, mocker) -> None: + def test_no_change_if_timestamps_already_set( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test timestamps are not modified if already set during insert.""" # Mock time.time() to return a fixed timestamp mocker.patch("time.time", return_value=1234567890) @@ -164,7 +181,9 @@ def test_no_change_if_timestamps_already_set(self, db_mock, mocker) -> None: assert returned_instance.created_at == 1111111111 assert returned_instance.updated_at == 1111111111 - def test_override_but_no_timestamps_provided(self, db_mock, mocker) -> None: + def test_override_but_no_timestamps_provided( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test missing timestamps always set to current time. Even with `timestamp_override=True`. @@ -191,7 +210,9 @@ def test_override_but_no_timestamps_provided(self, db_mock, mocker) -> None: assert returned_instance.created_at == 1234567890 assert returned_instance.updated_at == 1234567890 - def test_partial_override_with_zero(self, db_mock, mocker) -> None: + def test_partial_override_with_zero( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test changing `updated_at` only on create. When `timestamp_override=True @@ -218,7 +239,9 @@ def test_partial_override_with_zero(self, db_mock, mocker) -> None: assert returned_instance.created_at == 1111111111 assert returned_instance.updated_at == 1234567890 - def test_insert_with_override_disabled(self, db_mock, mocker) -> None: + def test_insert_with_override_disabled( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that timestamp_override=False ignores provided timestamps.""" # Mock time.time() to return a fixed timestamp mocker.patch("time.time", return_value=1234567890) @@ -241,7 +264,9 @@ def test_insert_with_override_disabled(self, db_mock, mocker) -> None: assert returned_instance.created_at == 1234567890 assert returned_instance.updated_at == 1234567890 - def test_time_is_in_utc(self, db_mock, mocker) -> None: + def test_time_is_in_utc( + self, db_mock: SqliterDB, mocker: MockerFixture + ) -> None: """Test that timestamps generated with time.time() are in UTC.""" # Mock time.time() to return a fixed timestamp mocker.patch("time.time", return_value=1234567890) diff --git a/tests/test_transaction_rollback.py b/tests/test_transaction_rollback.py index fddc0377..cd251bbb 100644 --- a/tests/test_transaction_rollback.py +++ b/tests/test_transaction_rollback.py @@ -6,6 +6,7 @@ import sqlite3 from contextlib import suppress +from pathlib import Path import pytest @@ -34,7 +35,7 @@ def _raise_error() -> None: class TestTransactionRollback: """Test transaction rollback behavior.""" - def test_insert_rollback_on_exception(self, tmp_path) -> None: + def test_insert_rollback_on_exception(self, tmp_path: Path) -> None: """Verify that insert is rolled back when an exception occurs.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file)) @@ -54,7 +55,7 @@ def test_insert_rollback_on_exception(self, tmp_path) -> None: assert len(result) == 0, "Insert should have been rolled back" - def test_update_rollback_on_exception(self, tmp_path) -> None: + def test_update_rollback_on_exception(self, tmp_path: Path) -> None: """Verify that update is rolled back when an exception occurs.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file)) @@ -82,7 +83,7 @@ def test_update_rollback_on_exception(self, tmp_path) -> None: f"Update should have been rolled back, got {result.quantity}" ) - def test_delete_rollback_on_exception(self, tmp_path) -> None: + def test_delete_rollback_on_exception(self, tmp_path: Path) -> None: """Verify that delete is rolled back when an exception occurs.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file)) @@ -104,7 +105,9 @@ def test_delete_rollback_on_exception(self, tmp_path) -> None: assert result is not None, "Delete should have been rolled back" - def test_query_builder_delete_rollback_on_exception(self, tmp_path) -> None: + def test_query_builder_delete_rollback_on_exception( + self, tmp_path: Path + ) -> None: """Verify that QueryBuilder.delete is rolled back on exception.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file)) @@ -128,7 +131,7 @@ def test_query_builder_delete_rollback_on_exception(self, tmp_path) -> None: "QueryBuilder delete should have been rolled back" ) - def test_multiple_operations_rollback(self, tmp_path) -> None: + def test_multiple_operations_rollback(self, tmp_path: Path) -> None: """Verify that multiple operations all rollback together.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file)) @@ -168,7 +171,7 @@ def test_multiple_operations_rollback(self, tmp_path) -> None: assert all_items[1].name == "Gadget" assert all_items[1].quantity == 20, "Gadget should not be deleted" - def test_transaction_commit_success(self, tmp_path) -> None: + def test_transaction_commit_success(self, tmp_path: Path) -> None: """Verify that successful transaction commits all changes.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file)) @@ -188,7 +191,9 @@ def test_transaction_commit_success(self, tmp_path) -> None: assert len(result) == 2, "Both inserts should have been committed" - def test_no_intermediate_commits_in_transaction(self, tmp_path) -> None: + def test_no_intermediate_commits_in_transaction( + self, tmp_path: Path + ) -> None: """Verify that data isn't committed before transaction ends.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file)) @@ -219,7 +224,9 @@ def test_no_intermediate_commits_in_transaction(self, tmp_path) -> None: assert count == 1, "Data should be visible after transaction commits" - def test_context_manager_sets_transaction_flag(self, tmp_path) -> None: + def test_context_manager_sets_transaction_flag( + self, tmp_path: Path + ) -> None: """Verify that context manager correctly sets _in_transaction flag.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file)) @@ -238,7 +245,7 @@ def test_context_manager_sets_transaction_flag(self, tmp_path) -> None: db.close() assert len(result) == 1 - def test_autocommit_false_still_rolls_back(self, tmp_path) -> None: + def test_autocommit_false_still_rolls_back(self, tmp_path: Path) -> None: """Verify rollback works even with auto_commit=False.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file), auto_commit=False) @@ -259,7 +266,7 @@ def test_autocommit_false_still_rolls_back(self, tmp_path) -> None: assert len(result) == 0, "Insert should have been rolled back" - def test_read_operation_in_transaction(self, tmp_path) -> None: + def test_read_operation_in_transaction(self, tmp_path: Path) -> None: """Verify read operations work inside a transaction.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file)) @@ -274,7 +281,7 @@ def test_read_operation_in_transaction(self, tmp_path) -> None: assert len(result) == 1 assert result[0].name == "Widget" - def test_exception_type_preserved(self, tmp_path) -> None: + def test_exception_type_preserved(self, tmp_path: Path) -> None: """Verify that the original exception is re-raised after rollback.""" db_file = tmp_path / "test_rollback.db" db = SqliterDB(db_filename=str(db_file)) diff --git a/tests/test_unique.py b/tests/test_unique.py index 67b4ec49..841a682c 100644 --- a/tests/test_unique.py +++ b/tests/test_unique.py @@ -4,6 +4,7 @@ from typing import Annotated, Union import pytest +from pytest_mock import MockerFixture from sqliter import SqliterDB from sqliter.exceptions import RecordInsertionError, RecordUpdateError @@ -68,7 +69,9 @@ class User(BaseDBModel): assert "UNIQUE constraint failed: users.name" in str(excinfo.value) - def test_unique_constraint_sql_generation(self, mocker) -> None: + def test_unique_constraint_sql_generation( + self, mocker: MockerFixture + ) -> None: """Test that the correct SQL for the Unique constraint is generated.""" class User(BaseDBModel): diff --git a/tests/tui/test_app.py b/tests/tui/test_app.py index 2716b27b..420fa69f 100644 --- a/tests/tui/test_app.py +++ b/tests/tui/test_app.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any, cast + import pytest from textual.css.query import NoMatches from textual.widgets import Button, Footer, Header, Tree @@ -16,9 +18,12 @@ OutputDisplay, ) +if TYPE_CHECKING: + from pytest_mock import MockerFixture + @pytest.fixture -def registered_demo(reset_demo_registry) -> Demo: +def registered_demo(reset_demo_registry: None) -> Demo: """Register a minimal demo and category for testing. Returns: @@ -42,7 +47,7 @@ class TestSQLiterDemoAppComposition: """Test app composition and layout.""" @pytest.mark.asyncio - async def test_app_composition(self, reset_demo_registry) -> None: + async def test_app_composition(self, reset_demo_registry: None) -> None: """Test that all main widgets are rendered.""" demo = Demo( id="test", @@ -75,7 +80,7 @@ async def test_app_composition(self, reset_demo_registry) -> None: assert output_display is not None @pytest.mark.asyncio - async def test_header_exists(self, registered_demo) -> None: + async def test_header_exists(self, registered_demo: Demo) -> None: """Test that Header widget exists.""" app = SQLiterDemoApp() async with app.run_test() as _: @@ -83,7 +88,7 @@ async def test_header_exists(self, registered_demo) -> None: assert header is not None @pytest.mark.asyncio - async def test_footer_exists(self, registered_demo) -> None: + async def test_footer_exists(self, registered_demo: Demo) -> None: """Test that Footer widget exists.""" app = SQLiterDemoApp() async with app.run_test() as _: @@ -91,7 +96,7 @@ async def test_footer_exists(self, registered_demo) -> None: assert footer is not None @pytest.mark.asyncio - async def test_demo_list_exists(self, registered_demo) -> None: + async def test_demo_list_exists(self, registered_demo: Demo) -> None: """Test that DemoList widget exists.""" app = SQLiterDemoApp() async with app.run_test() as _: @@ -99,7 +104,7 @@ async def test_demo_list_exists(self, registered_demo) -> None: assert demo_list is not None @pytest.mark.asyncio - async def test_code_display_exists(self, registered_demo) -> None: + async def test_code_display_exists(self, registered_demo: Demo) -> None: """Test that CodeDisplay widget exists.""" app = SQLiterDemoApp() async with app.run_test() as _: @@ -107,7 +112,7 @@ async def test_code_display_exists(self, registered_demo) -> None: assert code_display is not None @pytest.mark.asyncio - async def test_output_display_exists(self, registered_demo) -> None: + async def test_output_display_exists(self, registered_demo: Demo) -> None: """Test that OutputDisplay widget exists.""" app = SQLiterDemoApp() async with app.run_test() as _: @@ -115,7 +120,7 @@ async def test_output_display_exists(self, registered_demo) -> None: assert output_display is not None @pytest.mark.asyncio - async def test_buttons_exist(self, registered_demo) -> None: + async def test_buttons_exist(self, registered_demo: Demo) -> None: """Test that Run and Clear buttons exist.""" app = SQLiterDemoApp() async with app.run_test() as _: @@ -132,7 +137,7 @@ class TestSQLiterDemoAppFocus: """Test focus and navigation.""" @pytest.mark.asyncio - async def test_initial_focus_on_tree(self, registered_demo) -> None: + async def test_initial_focus_on_tree(self, registered_demo: Demo) -> None: """Test that the tree is focused on app mount.""" app = SQLiterDemoApp() async with app.run_test() as _: @@ -146,7 +151,7 @@ class TestSQLiterDemoAppDemoSelection: @pytest.mark.asyncio async def test_demo_selection_updates_code( - self, reset_demo_registry + self, reset_demo_registry: None ) -> None: """Test that selecting a demo updates the code display.""" demo = Demo( @@ -172,7 +177,9 @@ async def test_demo_selection_updates_code( assert "print('hello')" in code_display.code @pytest.mark.asyncio - async def test_demo_selection_stores_current(self, registered_demo) -> None: + async def test_demo_selection_stores_current( + self, registered_demo: Demo + ) -> None: """Test that selecting a demo stores it as current.""" app = SQLiterDemoApp() async with app.run_test() as pilot: @@ -187,7 +194,9 @@ class TestSQLiterDemoAppExecution: """Test demo execution functionality.""" @pytest.mark.asyncio - async def test_run_demo_with_selection(self, reset_demo_registry) -> None: + async def test_run_demo_with_selection( + self, reset_demo_registry: None + ) -> None: """Test running a demo when one is selected.""" demo = Demo( id="test", @@ -219,7 +228,7 @@ async def test_run_demo_with_selection(self, reset_demo_registry) -> None: @pytest.mark.asyncio async def test_run_demo_without_selection( self, - registered_demo, + registered_demo: Demo, ) -> None: """Test running a demo without selecting one first.""" app = SQLiterDemoApp() @@ -231,13 +240,14 @@ async def test_run_demo_without_selection( output_display = app.query_one("#output-display", OutputDisplay) # Should show error message - content = str( - output_display.query_one("#output-content").content - ).lower() + output_widget = cast( + "Any", output_display.query_one("#output-content") + ) + content = str(output_widget.content).lower() assert "select a demo" in content @pytest.mark.asyncio - async def test_clear_output(self, registered_demo) -> None: + async def test_clear_output(self, registered_demo: Demo) -> None: """Test clearing the output display.""" app = SQLiterDemoApp() async with app.run_test() as pilot: @@ -252,7 +262,7 @@ async def test_clear_output(self, registered_demo) -> None: await pilot.pause() # Should be back to placeholder - content = output_display.query_one("#output-content") + content = cast("Any", output_display.query_one("#output-content")) assert "Run a demo" in str(content.content) @@ -260,7 +270,7 @@ class TestSQLiterDemoAppHelpScreen: """Test help screen functionality.""" @pytest.mark.asyncio - async def test_help_screen_composition(self, registered_demo) -> None: + async def test_help_screen_composition(self, registered_demo: Demo) -> None: """Test that help screen can be shown.""" app = SQLiterDemoApp() async with app.run_test() as pilot: @@ -273,7 +283,7 @@ async def test_help_screen_composition(self, registered_demo) -> None: assert len(app.screen_stack) > 1 @pytest.mark.asyncio - async def test_help_key_opens_help(self, registered_demo) -> None: + async def test_help_key_opens_help(self, registered_demo: Demo) -> None: """Test that '?' key opens help.""" app = SQLiterDemoApp() async with app.run_test() as pilot: @@ -284,7 +294,7 @@ async def test_help_key_opens_help(self, registered_demo) -> None: assert len(app.screen_stack) > 1 @pytest.mark.asyncio - async def test_f1_opens_help(self, registered_demo) -> None: + async def test_f1_opens_help(self, registered_demo: Demo) -> None: """Test that F1 key opens help.""" app = SQLiterDemoApp() async with app.run_test() as pilot: @@ -296,7 +306,7 @@ async def test_f1_opens_help(self, registered_demo) -> None: assert len(app.screen_stack) > 1 @pytest.mark.asyncio - async def test_escape_closes_help(self, registered_demo) -> None: + async def test_escape_closes_help(self, registered_demo: Demo) -> None: """Test that Escape closes help screen.""" app = SQLiterDemoApp() async with app.run_test() as pilot: @@ -312,7 +322,7 @@ async def test_escape_closes_help(self, registered_demo) -> None: assert len(app.screen_stack) == 1 @pytest.mark.asyncio - async def test_q_closes_help(self, registered_demo) -> None: + async def test_q_closes_help(self, registered_demo: Demo) -> None: """Test that help screen can be dismissed programmatically.""" app = SQLiterDemoApp() async with app.run_test() as pilot: @@ -333,7 +343,7 @@ class TestSQLiterDemoAppKeyboardBindings: """Test keyboard bindings.""" @pytest.mark.asyncio - async def test_f5_runs_demo(self, registered_demo) -> None: + async def test_f5_runs_demo(self, registered_demo: Demo) -> None: """Test that F5 runs the demo.""" app = SQLiterDemoApp() async with app.run_test() as pilot: @@ -349,7 +359,7 @@ async def test_f5_runs_demo(self, registered_demo) -> None: assert output_display is not None @pytest.mark.asyncio - async def test_f8_clears_output(self, registered_demo) -> None: + async def test_f8_clears_output(self, registered_demo: Demo) -> None: """Test that F8 clears output.""" app = SQLiterDemoApp() async with app.run_test() as pilot: @@ -366,11 +376,11 @@ async def test_f8_clears_output(self, registered_demo) -> None: await pilot.pause() # Should be cleared - content = output_display.query_one("#output-content") + content = cast("Any", output_display.query_one("#output-content")) assert "Run a demo" in str(content.content) @pytest.mark.asyncio - async def test_vim_keys_work(self, registered_demo) -> None: + async def test_vim_keys_work(self, registered_demo: Demo) -> None: """Test that vim-style j/k keys work.""" app = SQLiterDemoApp() async with app.run_test() as pilot: @@ -394,7 +404,7 @@ class TestSQLiterDemoAppErrorHandling: @pytest.mark.asyncio async def test_run_demo_failure_shows_error( - self, reset_demo_registry + self, reset_demo_registry: None ) -> None: """Test that failed demo execution shows error message.""" @@ -426,7 +436,7 @@ def failing_execute() -> str: # Should show error output output_display = app.query_one("#output-display", OutputDisplay) - content = output_display.query_one("#output-content") + content = cast("Any", output_display.query_one("#output-content")) output_str = str(content.content).lower() # Error output should contain error information assert "exception" in output_str or "error" in output_str @@ -436,7 +446,7 @@ class TestSQLiterDemoAppEdgeCases: """Test edge cases and error handling.""" @pytest.mark.asyncio - async def test_empty_registry(self, reset_demo_registry) -> None: + async def test_empty_registry(self, reset_demo_registry: None) -> None: """Test app with empty demo registry.""" # Don't register any demos app = SQLiterDemoApp() @@ -446,7 +456,9 @@ async def test_empty_registry(self, reset_demo_registry) -> None: assert demo_list is not None @pytest.mark.asyncio - async def test_no_matches_exception_handling(self, registered_demo) -> None: + async def test_no_matches_exception_handling( + self, registered_demo: Demo + ) -> None: """Test graceful handling of NoMatches exception.""" app = SQLiterDemoApp() async with app.run_test() as _: @@ -456,7 +468,7 @@ async def test_no_matches_exception_handling(self, registered_demo) -> None: @pytest.mark.asyncio async def test_cursor_down_handles_missing_tree( - self, registered_demo, mocker + self, registered_demo: Demo, mocker: MockerFixture ) -> None: """Test cursor down handles missing tree gracefully.""" app = SQLiterDemoApp() @@ -468,7 +480,7 @@ async def test_cursor_down_handles_missing_tree( @pytest.mark.asyncio async def test_cursor_up_handles_missing_tree( - self, registered_demo, mocker + self, registered_demo: Demo, mocker: MockerFixture ) -> None: """Test cursor up handles missing tree gracefully.""" app = SQLiterDemoApp() @@ -479,14 +491,16 @@ async def test_cursor_up_handles_missing_tree( app.action_tree_cursor_up() @pytest.mark.asyncio - async def test_app_properties(self, registered_demo) -> None: + async def test_app_properties(self, registered_demo: Demo) -> None: """Test app properties.""" app = SQLiterDemoApp() assert app.TITLE == "SQLiter Interactive Demo" assert app.CSS_PATH == "styles/app.tcss" @pytest.mark.asyncio - async def test_multiple_demo_selections(self, reset_demo_registry) -> None: + async def test_multiple_demo_selections( + self, reset_demo_registry: None + ) -> None: """Test selecting multiple demos in sequence.""" demo1 = Demo( id="demo1", diff --git a/tests/tui/test_demo_registry.py b/tests/tui/test_demo_registry.py index 33ed061f..9d8f300c 100644 --- a/tests/tui/test_demo_registry.py +++ b/tests/tui/test_demo_registry.py @@ -11,7 +11,7 @@ class TestDemoRegistry: """Test the DemoRegistry class.""" - def test_register_category(self, reset_demo_registry) -> None: + def test_register_category(self, reset_demo_registry: None) -> None: """Test registering a single category.""" category = DemoCategory( id="test_category", @@ -25,7 +25,9 @@ def test_register_category(self, reset_demo_registry) -> None: assert len(categories) == 1 assert categories[0] == category - def test_register_multiple_categories(self, reset_demo_registry) -> None: + def test_register_multiple_categories( + self, reset_demo_registry: None + ) -> None: """Test registering multiple categories.""" cat1 = DemoCategory(id="cat1", title="Category 1") cat2 = DemoCategory(id="cat2", title="Category 2") @@ -39,7 +41,7 @@ def test_register_multiple_categories(self, reset_demo_registry) -> None: assert len(categories) == 3 assert categories == (cat1, cat2, cat3) - def test_get_demo_by_id(self, reset_demo_registry) -> None: + def test_get_demo_by_id(self, reset_demo_registry: None) -> None: """Test retrieving a demo by its ID.""" demo = Demo( id="test_demo", @@ -58,12 +60,12 @@ def test_get_demo_by_id(self, reset_demo_registry) -> None: assert retrieved.id == "test_demo" assert retrieved.title == "Test Demo" - def test_get_demo_not_found(self, reset_demo_registry) -> None: + def test_get_demo_not_found(self, reset_demo_registry: None) -> None: """Test retrieving a non-existent demo.""" result = DemoRegistry.get_demo("nonexistent") assert result is None - def test_get_demo_code_with_setup(self, reset_demo_registry) -> None: + def test_get_demo_code_with_setup(self, reset_demo_registry: None) -> None: """Test getting demo code including setup.""" demo = Demo( id="test_demo", @@ -83,7 +85,9 @@ def test_get_demo_code_with_setup(self, reset_demo_registry) -> None: assert "print('setup')" in code assert "print('main')" in code - def test_get_demo_code_without_setup(self, reset_demo_registry) -> None: + def test_get_demo_code_without_setup( + self, reset_demo_registry: None + ) -> None: """Test getting demo code without setup.""" demo = Demo( id="test_demo", @@ -100,12 +104,12 @@ def test_get_demo_code_without_setup(self, reset_demo_registry) -> None: code = DemoRegistry.get_demo_code("test_demo") assert code == "print('main')" - def test_get_demo_code_not_found(self, reset_demo_registry) -> None: + def test_get_demo_code_not_found(self, reset_demo_registry: None) -> None: """Test getting code for non-existent demo.""" code = DemoRegistry.get_demo_code("nonexistent") assert code == "" - def test_reset_registry(self, reset_demo_registry) -> None: + def test_reset_registry(self, reset_demo_registry: None) -> None: """Test resetting the registry.""" category = DemoCategory(id="test", title="Test") DemoRegistry.register_category(category) @@ -117,7 +121,7 @@ def test_reset_registry(self, reset_demo_registry) -> None: assert len(DemoRegistry.get_categories()) == 0 assert DemoRegistry.get_demo("test") is None - def test_demo_id_uniqueness(self, reset_demo_registry) -> None: + def test_demo_id_uniqueness(self, reset_demo_registry: None) -> None: """Test that duplicate demo IDs raise ValueError.""" demo1 = Demo( id="duplicate", @@ -145,7 +149,9 @@ def test_demo_id_uniqueness(self, reset_demo_registry) -> None: with pytest.raises(ValueError, match="Duplicate demo id: duplicate"): DemoRegistry.register_category(cat2) - def test_get_categories_returns_sequence(self, reset_demo_registry) -> None: + def test_get_categories_returns_sequence( + self, reset_demo_registry: None + ) -> None: """Test that get_categories returns a sequence.""" category = DemoCategory(id="test", title="Test") DemoRegistry.register_category(category) @@ -156,7 +162,9 @@ def test_get_categories_returns_sequence(self, reset_demo_registry) -> None: # Should support len() assert len(categories) == 1 - def test_category_with_multiple_demos(self, reset_demo_registry) -> None: + def test_category_with_multiple_demos( + self, reset_demo_registry: None + ) -> None: """Test a category with multiple demos.""" def make_demo(idx: int) -> Demo: @@ -184,7 +192,7 @@ def execute() -> str: assert retrieved is not None assert retrieved.id == f"demo{i}" - def test_demo_code_formatting(self, reset_demo_registry) -> None: + def test_demo_code_formatting(self, reset_demo_registry: None) -> None: """Test that demo code is formatted correctly.""" demo = Demo( id="test", diff --git a/tests/tui/test_init.py b/tests/tui/test_init.py index bbcff5de..c95da09c 100644 --- a/tests/tui/test_init.py +++ b/tests/tui/test_init.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest import sqliter.tui @@ -12,6 +14,9 @@ run, ) +if TYPE_CHECKING: + from pytest_mock import MockerFixture + if _TEXTUAL_AVAILABLE: from sqliter.tui.app import SQLiterDemoApp else: @@ -38,7 +43,9 @@ def test_get_app_with_textual(self) -> None: app = get_app() assert isinstance(app, SQLiterDemoApp) - def test_get_app_without_textual_raises_import_error(self, mocker) -> None: + def test_get_app_without_textual_raises_import_error( + self, mocker: MockerFixture + ) -> None: """Test get_app raises ImportError when textual is not available.""" # Mock _TEXTUAL_AVAILABLE to False mocker.patch("sqliter.tui._TEXTUAL_AVAILABLE", False) @@ -53,7 +60,7 @@ class TestRunFunction: """Test the run function.""" @pytest.mark.skipif(not _TEXTUAL_AVAILABLE, reason="textual not installed") - def test_run_calls_app_run(self, mocker) -> None: + def test_run_calls_app_run(self, mocker: MockerFixture) -> None: """Test that run() calls app.run().""" mock_run = mocker.patch("sqliter.tui.app.SQLiterDemoApp.run") run() diff --git a/tests/tui/test_main.py b/tests/tui/test_main.py index f1067c9a..acfc91c6 100644 --- a/tests/tui/test_main.py +++ b/tests/tui/test_main.py @@ -2,13 +2,18 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from sqliter.tui import __main__ +if TYPE_CHECKING: + from pytest_mock import MockerFixture + class TestMainModule: """Test the __main__ module execution.""" - def test_main_imports_run(self, mocker) -> None: + def test_main_imports_run(self, mocker: MockerFixture) -> None: """Test that __main__ imports run correctly.""" # The run function should be imported mocker.patch("sqliter.tui.run") diff --git a/tests/tui/unit/test_demos_others.py b/tests/tui/unit/test_demos_others.py index 82227ffd..2706ddfd 100644 --- a/tests/tui/unit/test_demos_others.py +++ b/tests/tui/unit/test_demos_others.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from sqliter.orm import ManyToMany @@ -14,6 +16,9 @@ transactions, ) +if TYPE_CHECKING: + from types import ModuleType + CATEGORY_MODULES = ( (caching, "caching"), (errors, "errors"), @@ -31,7 +36,7 @@ class TestGetCategories: ("module", "expected_id"), CATEGORY_MODULES, ) - def test_category_valid(self, module, expected_id) -> None: + def test_category_valid(self, module: ModuleType, expected_id: str) -> None: """Test that category is valid.""" category = module.get_category() assert category.id == expected_id @@ -41,7 +46,9 @@ def test_category_valid(self, module, expected_id) -> None: ("module", "expected_id"), CATEGORY_MODULES, ) - def test_all_demos_execute(self, module, expected_id) -> None: + def test_all_demos_execute( + self, module: ModuleType, expected_id: str + ) -> None: """Test that all demos execute.""" category = module.get_category() for demo in category.demos: diff --git a/tests/tui/widgets/test_demo_list.py b/tests/tui/widgets/test_demo_list.py index 647cde27..cf1b5bf8 100644 --- a/tests/tui/widgets/test_demo_list.py +++ b/tests/tui/widgets/test_demo_list.py @@ -21,7 +21,7 @@ class TestDemoList: @pytest.mark.asyncio async def test_demo_selection_posts_message( - self, reset_demo_registry + self, reset_demo_registry: None ) -> None: """Test that selecting a demo posts DemoSelected message.""" demo = Demo( @@ -76,7 +76,7 @@ def __init__( assert messages_received[0].id == "demo1" @pytest.mark.asyncio - async def test_tree_composition(self, reset_demo_registry) -> None: + async def test_tree_composition(self, reset_demo_registry: None) -> None: """Test that the widget composes a Tree.""" category = DemoCategory(id="test", title="Test Category", icon="๐Ÿงช") DemoRegistry.register_category(category) @@ -91,7 +91,7 @@ async def test_tree_composition(self, reset_demo_registry) -> None: assert tree.id == "demo-tree" @pytest.mark.asyncio - async def test_category_nodes(self, reset_demo_registry) -> None: + async def test_category_nodes(self, reset_demo_registry: None) -> None: """Test that categories are rendered as nodes.""" cat1 = DemoCategory(id="cat1", title="Category 1", icon="๐Ÿ“ฆ") cat2 = DemoCategory(id="cat2", title="Category 2", icon="๐Ÿ“") @@ -108,7 +108,7 @@ async def test_category_nodes(self, reset_demo_registry) -> None: assert len(tree.root.children) == 2 @pytest.mark.asyncio - async def test_demo_leaves(self, reset_demo_registry) -> None: + async def test_demo_leaves(self, reset_demo_registry: None) -> None: """Test that demos are rendered as leaf nodes.""" demo1 = Demo( id="demo1", @@ -143,7 +143,7 @@ async def test_demo_leaves(self, reset_demo_registry) -> None: assert len(cat_node.children) == 2 @pytest.mark.asyncio - async def test_node_expansion(self, reset_demo_registry) -> None: + async def test_node_expansion(self, reset_demo_registry: None) -> None: """Test that categories can be expanded/collapsed.""" category = DemoCategory(id="cat1", title="Category", expanded=False) demo = Demo( @@ -174,7 +174,9 @@ async def test_node_expansion(self, reset_demo_registry) -> None: assert cat_node.is_expanded @pytest.mark.asyncio - async def test_demo_selection_event(self, reset_demo_registry) -> None: + async def test_demo_selection_event( + self, reset_demo_registry: None + ) -> None: """Test that selecting a demo posts DemoSelected message.""" demo = Demo( id="demo1", @@ -204,7 +206,7 @@ async def test_demo_selection_event(self, reset_demo_registry) -> None: assert demo_node.data.id == "demo1" @pytest.mark.asyncio - async def test_empty_registry(self, reset_demo_registry) -> None: + async def test_empty_registry(self, reset_demo_registry: None) -> None: """Test that empty registry works.""" # Don't register any categories app: App[Any] = App() @@ -217,7 +219,9 @@ async def test_empty_registry(self, reset_demo_registry) -> None: assert len(tree.root.children) == 0 @pytest.mark.asyncio - async def test_integration_with_registry(self, reset_demo_registry) -> None: + async def test_integration_with_registry( + self, reset_demo_registry: None + ) -> None: """Test integration with DemoRegistry.""" # Register multiple categories with demos demos = [ @@ -254,7 +258,7 @@ async def test_integration_with_registry(self, reset_demo_registry) -> None: assert len(cat_node.children) == 3 @pytest.mark.asyncio - async def test_tree_root_hidden(self, reset_demo_registry) -> None: + async def test_tree_root_hidden(self, reset_demo_registry: None) -> None: """Test that the tree root is hidden.""" category = DemoCategory(id="test", title="Test") DemoRegistry.register_category(category) @@ -268,7 +272,9 @@ async def test_tree_root_hidden(self, reset_demo_registry) -> None: assert tree.show_root is False @pytest.mark.asyncio - async def test_demo_label_formatting(self, reset_demo_registry) -> None: + async def test_demo_label_formatting( + self, reset_demo_registry: None + ) -> None: """Test that demo labels are formatted correctly.""" demo = Demo( id="demo1",