Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions docs/fields/field-types.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,33 @@ Each of the `Fields` has assigned both `sqlalchemy` column class and python type
regex: str = None,)` has a required `max_length` parameter.

* Sqlalchemy column: `sqlalchemy.String`
* Type (used for pydantic): `str`
* Type (used for pydantic): `str`, or `typing.Literal[...]` when string `choices` are provided

!!!tip
For explanation of other parameters check [pydantic](https://pydantic-docs.helpmanual.io/usage/schema/#field-customisation) documentation.

You can also constrain string values with `choices` while keeping a varchar column:

```python
from typing import Literal

import ormar


class Account(ormar.Model):
ormar_config = ...

id: int = ormar.Integer(primary_key=True)
mode: Literal["user", "manager", "admin"] = ormar.String(
max_length=32,
choices=("user", "manager", "admin"),
)
```

For the strongest static typing, keep the explicit field annotation. If you rely on
inference from the `choices` argument alone, type checkers are more likely to infer a
`str` type instead of the appropriate literal.

### Text

`Text()` has no required parameters.
Expand Down Expand Up @@ -220,7 +242,6 @@ So which one to use depends on the backend you use and on the column/ data type
* Sqlalchemy column: `sqlalchemy.Enum`
* Type (used for pydantic): `type[Enum]`


[relations]: ../relations/index.md
[queries]: ../queries.md
[pydantic]: https://pydantic-docs.helpmanual.io/usage/schema/#field-customisation
Expand Down
17 changes: 16 additions & 1 deletion ormar/fields/model_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
from enum import Enum as E
from enum import EnumMeta
from typing import Any, Optional
from typing import Any, Literal, Optional

import pydantic
import sqlalchemy
Expand Down Expand Up @@ -162,6 +162,8 @@ def __new__( # type: ignore # noqa CFQ002
regex: Optional[str] = None,
**kwargs: Any
) -> Self: # type: ignore
choices = kwargs.get("choices")
nullable = kwargs.get("nullable")
kwargs = {
**kwargs,
**{
Expand All @@ -170,6 +172,11 @@ def __new__( # type: ignore # noqa CFQ002
if k not in ["cls", "__class__", "kwargs"]
},
}
overwrite_pydantic_type = cls._get_choices_pydantic_type(choices)
if overwrite_pydantic_type is not None:
if nullable is True:
overwrite_pydantic_type = Optional[overwrite_pydantic_type]
kwargs["overwrite_pydantic_type"] = overwrite_pydantic_type
return super().__new__(cls, **kwargs)

@classmethod
Expand Down Expand Up @@ -198,6 +205,14 @@ def validate(cls, **kwargs: Any) -> None:
"Parameter max_length is required for field String"
)

@staticmethod
def _get_choices_pydantic_type(choices: Any) -> Any:
if choices is None:
return None
if any(not isinstance(choice, str) for choice in choices):
raise ModelDefinitionError("String Field choices must be strings")
return Literal.__getitem__(tuple(choices))


class Integer(ModelFieldFactory, int):
"""
Expand Down
17 changes: 17 additions & 0 deletions ormar/fields/model_fields.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,29 @@ from typing import Any, Literal, TypeVar, Union, overload
from uuid import UUID as UuidType

T = TypeVar("T", bound=EnumBase)
L = TypeVar("L", bound=str)

@overload
def Boolean(*, nullable: Literal[False] = False, **kwargs: Any) -> bool: ...
@overload
def Boolean(*, nullable: Literal[True], **kwargs: Any) -> bool | None: ...
@overload
def String(
*,
max_length: int,
choices: list[L] | tuple[L, ...],
nullable: Literal[False] = False,
**kwargs: Any,
) -> L: ...
@overload
def String(
*,
max_length: int,
choices: list[L] | tuple[L, ...],
nullable: Literal[True],
**kwargs: Any,
) -> L | None: ...
@overload
def String(
*, max_length: int, nullable: Literal[False] = False, **kwargs: Any
) -> str: ...
Expand Down
48 changes: 48 additions & 0 deletions tests/test_fastapi/test_string_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Literal

import ormar

from tests.lifespan import init_tests
from tests.settings import create_config

base_ormar_config = create_config()


class StringChoicesExample(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="string_choices_example")

id: int = ormar.Integer(primary_key=True)
mode: Literal["user", "manager", "admin"] = ormar.String(
max_length=32, choices=["user", "manager", "admin"]
)
optional_mode: Literal["user", "manager", "admin"] | None = ormar.String(
max_length=32,
choices=["user", "manager", "admin"],
nullable=True,
)


create_test_database = init_tests(base_ormar_config)


def test_string_choices_schema():
schema = StringChoicesExample.model_json_schema()

assert schema["properties"]["mode"] == {
"enum": ["user", "manager", "admin"],
"maxLength": 32,
"title": "Mode",
"type": "string",
}
assert schema["properties"]["optional_mode"] == {
"anyOf": [
{
"enum": ["user", "manager", "admin"],
"maxLength": 32,
"type": "string",
},
{"type": "null"},
],
"default": None,
"title": "Optional Mode",
}
67 changes: 64 additions & 3 deletions tests/test_model_definition/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import os
import uuid
from enum import Enum
from typing import Optional
from typing import Literal, Optional, get_args

import ormar
import pydantic
import pytest
import sqlalchemy
from ormar.exceptions import ModelError, NoMatch, QueryDefinitionError

import ormar
from ormar.exceptions import ModelError, NoMatch, QueryDefinitionError
from tests.lifespan import init_tests
from tests.settings import create_config

Expand Down Expand Up @@ -139,6 +139,20 @@ class NotNullableCountry(ormar.Model):
name: CountryNameEnum = ormar.Enum(enum_class=CountryNameEnum, nullable=False)


class UserRole(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="user_roles")

id: int = ormar.Integer(primary_key=True)
mode: Literal["user", "manager", "admin"] = ormar.String(
max_length=32, choices=["user", "manager", "admin"]
)
optional_mode: Optional[Literal["user", "manager", "admin"]] = ormar.String(
max_length=32,
choices=["user", "manager", "admin"],
nullable=True,
)


create_test_database = init_tests(base_ormar_config)


Expand All @@ -158,6 +172,20 @@ def test_wrong_field_name():
User(non_existing_pk=1)


def test_string_choices_field_definition():
field = UserRole.ormar_config.model_fields["mode"]

assert isinstance(field.column_type, sqlalchemy.String)
assert field.column_type.length == 32
assert field.choices == ["user", "manager", "admin"]
assert get_args(field.__pydantic_type__) == ("user", "manager", "admin")


def test_string_field_rejects_non_string_choices():
with pytest.raises(ormar.ModelDefinitionError):
ormar.String(max_length=8, choices=["user", 1]) # type: ignore[list-item]


def test_model_pk():
user = User(pk=1)
assert user.pk == 1
Expand Down Expand Up @@ -525,6 +553,39 @@ async def test_nullable_field_model_enum():
await NotNullableCountry(name=None).save()


@pytest.mark.asyncio
async def test_string_choices_field_validation():
async with base_ormar_config.database:
async with base_ormar_config.database.transaction(force_rollback=True):
role = await UserRole.objects.create(mode="user", optional_mode=None)
assert role.mode == "user"
assert role.optional_mode is None

role.optional_mode = "manager"
assert role.optional_mode == "manager"

role.mode = "admin"
await role.update()
refreshed = await UserRole.objects.get(pk=role.pk)
assert refreshed.mode == "admin"
assert refreshed.optional_mode == "manager"

with pytest.raises(ValueError):
UserRole(mode="guest")

with pytest.raises(ValueError):
role.mode = "guest"

with pytest.raises(ValueError):
await UserRole(mode="guest").save()

with pytest.raises(ValueError):
await UserRole.objects.filter(pk=role.pk).update(mode="guest")

with pytest.raises(ValueError):
role.optional_mode = "guest"


@pytest.mark.asyncio
async def test_start_and_end_filters():
async with base_ormar_config.database:
Expand Down