diff --git a/docs/fields/field-types.md b/docs/fields/field-types.md index 3e13ff998..36cf996b8 100644 --- a/docs/fields/field-types.md +++ b/docs/fields/field-types.md @@ -19,11 +19,33 @@ Each of the `Fields` has assigned both `sqlalchemy` column class and python type regex: str = None,)` has a required `max_length` parameter. * Sqlalchemy column: `sqlalchemy.String` -* Type (used for pydantic): `str` +* Type (used for pydantic): `str`, or `typing.Literal[...]` when string `choices` are provided !!!tip For explanation of other parameters check [pydantic](https://pydantic-docs.helpmanual.io/usage/schema/#field-customisation) documentation. +You can also constrain string values with `choices` while keeping a varchar column: + +```python +from typing import Literal + +import ormar + + +class Account(ormar.Model): + ormar_config = ... + + id: int = ormar.Integer(primary_key=True) + mode: Literal["user", "manager", "admin"] = ormar.String( + max_length=32, + choices=("user", "manager", "admin"), + ) +``` + +For the strongest static typing, keep the explicit field annotation. If you rely on +inference from the `choices` argument alone, type checkers are more likely to infer a +`str` type instead of the appropriate literal. + ### Text `Text()` has no required parameters. @@ -220,7 +242,6 @@ So which one to use depends on the backend you use and on the column/ data type * Sqlalchemy column: `sqlalchemy.Enum` * Type (used for pydantic): `type[Enum]` - [relations]: ../relations/index.md [queries]: ../queries.md [pydantic]: https://pydantic-docs.helpmanual.io/usage/schema/#field-customisation diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index f6b589480..2e1c00804 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -3,7 +3,7 @@ import uuid from enum import Enum as E from enum import EnumMeta -from typing import Any, Optional +from typing import Any, Literal, Optional import pydantic import sqlalchemy @@ -162,6 +162,8 @@ def __new__( # type: ignore # noqa CFQ002 regex: Optional[str] = None, **kwargs: Any ) -> Self: # type: ignore + choices = kwargs.get("choices") + nullable = kwargs.get("nullable") kwargs = { **kwargs, **{ @@ -170,6 +172,11 @@ def __new__( # type: ignore # noqa CFQ002 if k not in ["cls", "__class__", "kwargs"] }, } + overwrite_pydantic_type = cls._get_choices_pydantic_type(choices) + if overwrite_pydantic_type is not None: + if nullable is True: + overwrite_pydantic_type = Optional[overwrite_pydantic_type] + kwargs["overwrite_pydantic_type"] = overwrite_pydantic_type return super().__new__(cls, **kwargs) @classmethod @@ -198,6 +205,14 @@ def validate(cls, **kwargs: Any) -> None: "Parameter max_length is required for field String" ) + @staticmethod + def _get_choices_pydantic_type(choices: Any) -> Any: + if choices is None: + return None + if any(not isinstance(choice, str) for choice in choices): + raise ModelDefinitionError("String Field choices must be strings") + return Literal.__getitem__(tuple(choices)) + class Integer(ModelFieldFactory, int): """ diff --git a/ormar/fields/model_fields.pyi b/ormar/fields/model_fields.pyi index e642c14a6..362242168 100644 --- a/ormar/fields/model_fields.pyi +++ b/ormar/fields/model_fields.pyi @@ -7,12 +7,29 @@ from typing import Any, Literal, TypeVar, Union, overload from uuid import UUID as UuidType T = TypeVar("T", bound=EnumBase) +L = TypeVar("L", bound=str) @overload def Boolean(*, nullable: Literal[False] = False, **kwargs: Any) -> bool: ... @overload def Boolean(*, nullable: Literal[True], **kwargs: Any) -> bool | None: ... @overload +def String( + *, + max_length: int, + choices: list[L] | tuple[L, ...], + nullable: Literal[False] = False, + **kwargs: Any, +) -> L: ... +@overload +def String( + *, + max_length: int, + choices: list[L] | tuple[L, ...], + nullable: Literal[True], + **kwargs: Any, +) -> L | None: ... +@overload def String( *, max_length: int, nullable: Literal[False] = False, **kwargs: Any ) -> str: ... diff --git a/tests/test_fastapi/test_string_schema.py b/tests/test_fastapi/test_string_schema.py new file mode 100644 index 000000000..6bab211b5 --- /dev/null +++ b/tests/test_fastapi/test_string_schema.py @@ -0,0 +1,48 @@ +from typing import Literal + +import ormar + +from tests.lifespan import init_tests +from tests.settings import create_config + +base_ormar_config = create_config() + + +class StringChoicesExample(ormar.Model): + ormar_config = base_ormar_config.copy(tablename="string_choices_example") + + id: int = ormar.Integer(primary_key=True) + mode: Literal["user", "manager", "admin"] = ormar.String( + max_length=32, choices=["user", "manager", "admin"] + ) + optional_mode: Literal["user", "manager", "admin"] | None = ormar.String( + max_length=32, + choices=["user", "manager", "admin"], + nullable=True, + ) + + +create_test_database = init_tests(base_ormar_config) + + +def test_string_choices_schema(): + schema = StringChoicesExample.model_json_schema() + + assert schema["properties"]["mode"] == { + "enum": ["user", "manager", "admin"], + "maxLength": 32, + "title": "Mode", + "type": "string", + } + assert schema["properties"]["optional_mode"] == { + "anyOf": [ + { + "enum": ["user", "manager", "admin"], + "maxLength": 32, + "type": "string", + }, + {"type": "null"}, + ], + "default": None, + "title": "Optional Mode", + } diff --git a/tests/test_model_definition/test_models.py b/tests/test_model_definition/test_models.py index d994e5cc7..f5b66fedf 100644 --- a/tests/test_model_definition/test_models.py +++ b/tests/test_model_definition/test_models.py @@ -4,14 +4,14 @@ import os import uuid from enum import Enum -from typing import Optional +from typing import Literal, Optional, get_args -import ormar import pydantic import pytest import sqlalchemy -from ormar.exceptions import ModelError, NoMatch, QueryDefinitionError +import ormar +from ormar.exceptions import ModelError, NoMatch, QueryDefinitionError from tests.lifespan import init_tests from tests.settings import create_config @@ -139,6 +139,20 @@ class NotNullableCountry(ormar.Model): name: CountryNameEnum = ormar.Enum(enum_class=CountryNameEnum, nullable=False) +class UserRole(ormar.Model): + ormar_config = base_ormar_config.copy(tablename="user_roles") + + id: int = ormar.Integer(primary_key=True) + mode: Literal["user", "manager", "admin"] = ormar.String( + max_length=32, choices=["user", "manager", "admin"] + ) + optional_mode: Optional[Literal["user", "manager", "admin"]] = ormar.String( + max_length=32, + choices=["user", "manager", "admin"], + nullable=True, + ) + + create_test_database = init_tests(base_ormar_config) @@ -158,6 +172,20 @@ def test_wrong_field_name(): User(non_existing_pk=1) +def test_string_choices_field_definition(): + field = UserRole.ormar_config.model_fields["mode"] + + assert isinstance(field.column_type, sqlalchemy.String) + assert field.column_type.length == 32 + assert field.choices == ["user", "manager", "admin"] + assert get_args(field.__pydantic_type__) == ("user", "manager", "admin") + + +def test_string_field_rejects_non_string_choices(): + with pytest.raises(ormar.ModelDefinitionError): + ormar.String(max_length=8, choices=["user", 1]) # type: ignore[list-item] + + def test_model_pk(): user = User(pk=1) assert user.pk == 1 @@ -525,6 +553,39 @@ async def test_nullable_field_model_enum(): await NotNullableCountry(name=None).save() +@pytest.mark.asyncio +async def test_string_choices_field_validation(): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): + role = await UserRole.objects.create(mode="user", optional_mode=None) + assert role.mode == "user" + assert role.optional_mode is None + + role.optional_mode = "manager" + assert role.optional_mode == "manager" + + role.mode = "admin" + await role.update() + refreshed = await UserRole.objects.get(pk=role.pk) + assert refreshed.mode == "admin" + assert refreshed.optional_mode == "manager" + + with pytest.raises(ValueError): + UserRole(mode="guest") + + with pytest.raises(ValueError): + role.mode = "guest" + + with pytest.raises(ValueError): + await UserRole(mode="guest").save() + + with pytest.raises(ValueError): + await UserRole.objects.filter(pk=role.pk).update(mode="guest") + + with pytest.raises(ValueError): + role.optional_mode = "guest" + + @pytest.mark.asyncio async def test_start_and_end_filters(): async with base_ormar_config.database: