Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions src/django_enum/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,15 @@ def __init__(
self._strict_ = strict if enum else False
self._coerce_ = coerce if enum else False
self._constrained_ = constrained if constrained is not None else strict
self._choices_override_ = kwargs.get("choices")
if self.enum is not None:
kwargs.setdefault("choices", choices(enum))

if django_version < (5, 0) and "choices" in kwargs:
from django_enum.utils import normalize_choices

kwargs["choices"] = normalize_choices(kwargs["choices"])

super().__init__(
null=kwargs.pop("null", False) or None in values(self.enum), **kwargs
)
Expand Down Expand Up @@ -556,7 +563,13 @@ def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]:
"""
name, path, args, kwargs = super().deconstruct()
if self.enum is not None:
kwargs["choices"] = choices(self.enum)
# if choices are overridden, super().deconstruct() might have
# normalized them. We want to preserve the original format if
# possible, but only if it's not the default enum choices.
if self._choices_override_ is not None:
kwargs["choices"] = self._choices_override_
elif "choices" not in kwargs:
kwargs["choices"] = choices(self.enum)

if "db_default" in kwargs:
try:
Expand Down Expand Up @@ -726,15 +739,30 @@ def get_choices(
limit_choices_to=None,
ordering=(),
):
return [
(getattr(choice, "value", choice), label)
for choice, label in super().get_choices(
def _coerce(choice):
return getattr(choice, "value", choice)

def _recursive_coerce(choices_list):
coerced = []
for item in choices_list:
if isinstance(item, (list, tuple)) and len(item) == 2:
choice, label = item
if isinstance(label, (list, tuple)):
coerced.append((_coerce(choice), _recursive_coerce(label)))
else:
coerced.append((_coerce(choice), label))
else:
coerced.append((_coerce(item), item))
return coerced

return _recursive_coerce(
super().get_choices(
include_blank=include_blank,
blank_choice=list(blank_choice),
limit_choices_to=limit_choices_to,
ordering=ordering,
)
]
)

@staticmethod
def constraint_name(
Expand Down
76 changes: 60 additions & 16 deletions src/django_enum/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,41 @@ class NonStrictMixin:

choices: _SelectChoices

def render(self, *args, **kwargs):
def render(self, name, value, attrs=None, renderer=None):
"""
Before rendering if we're a non-strict field and our value is not
one of our choices, we add it as an option.
"""

value: t.Any = getattr(kwargs.get("value"), "value", kwargs.get("value"))
if value not in EnumChoiceField.empty_values and value not in (
choice[0] for choice in self.choices
):
self.choices = list(self.choices) + [(value, str(value))]
return super().render(*args, **kwargs) # type: ignore[misc]
val: t.Any = getattr(value, "value", value)

def _has_value(choices_to_search):
if isinstance(choices_to_search, dict):
for k, v in choices_to_search.items():
if isinstance(v, (dict, list, tuple)):
if _has_value(v):
return True
elif k == val:
return True
return False
for item in choices_to_search:
if isinstance(item, (list, tuple)) and len(item) == 2:
choice, label = item
if isinstance(label, (list, tuple)):
if _has_value(label):
return True
elif choice == val:
return True
elif item == val:
return True
return False

if val not in EnumChoiceField.empty_values and not _has_value(self.choices):
if isinstance(self.choices, dict):
self.choices = list(self.choices.items()) + [(val, str(val))]
else:
self.choices = list(self.choices) + [(val, str(val))]
return super().render(name, value, attrs=attrs, renderer=renderer) # type: ignore[misc]


class NonStrictFlagMixin:
Expand All @@ -106,22 +129,38 @@ class NonStrictFlagMixin:

choices: _SelectChoices

def render(self, *args, **kwargs):
def render(self, name, value, attrs=None, renderer=None):
"""
Before rendering if we're a non-strict flag field and bits are set that are
not part of our flag enumeration we add them as (integer value, bit index)
to our (value, label) choice list.
"""

raw_choices = zip(
get_set_values(kwargs.get("value")), get_set_bits(kwargs.get("value"))
)
raw_choices = zip(get_set_values(value), get_set_bits(value))
self.choices = list(self.choices)
choice_values = set(choice[0] for choice in self.choices)
for value, label in raw_choices:
if value not in choice_values:
self.choices.append((value, label))
return super().render(*args, **kwargs) # type: ignore[misc]

Comment on lines +139 to +141
Copy link

Copilot AI Mar 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NonStrictFlagMixin.render() converts self.choices with list(self.choices). If self.choices is a dict (allowed on Django 5+), this becomes a list of keys, which will break the subsequent (choice, label) unpacking in _get_values() and also makes the later isinstance(self.choices, dict) branch unreachable. Consider preserving the original dict (or converting dicts via .items()) before iterating/unpacking so dict-based choices work correctly.

Copilot uses AI. Check for mistakes.
def _get_values(choices_to_search):
if isinstance(choices_to_search, dict):
for k, v in choices_to_search.items():
if isinstance(v, (dict, list, tuple)):
yield from _get_values(v)
else:
yield k
return
for choice, label in choices_to_search:
if isinstance(label, (list, tuple)):
yield from _get_values(label)
else:
yield choice

choice_values = set(_get_values(self.choices))
for v, label in raw_choices:
if v not in choice_values:
if isinstance(self.choices, dict):
self.choices = list(self.choices.items()) + [(v, label)]
else:
self.choices.append((v, label))
return super().render(name, value, attrs=attrs, renderer=renderer) # type: ignore[misc]


class NonStrictSelect(NonStrictMixin, Select):
Expand Down Expand Up @@ -353,6 +392,11 @@ def __init__(
self.empty_values = empty_values
self._empty_values_overridden_ = True

from django_enum.utils import django_version, normalize_choices

if django_version < (5, 0) and choices:
choices = normalize_choices(choices)

super().__init__(
choices=choices or getattr(self.enum, "choices", choices),
coerce=coerce or self.default_coerce,
Expand Down
23 changes: 23 additions & 0 deletions src/django_enum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
get_args,
)

from django import VERSION as django_version

__all__ = [
"choices",
"names",
Expand All @@ -26,6 +28,8 @@
"get_set_bits",
"decompose",
"members",
"normalize_choices",
"django_version",
]


Expand Down Expand Up @@ -323,3 +327,22 @@ def members(enum: type[E], aliases: bool = True) -> Generator[E, None, None]:
else:
for name in enum._member_names_:
yield enum[name] # type: ignore[misc]


def normalize_choices(choices):
"""
Standardize choices for Django < 5.0.
In Django 5.0, choices can be a dict or a callable.
"""
if callable(choices):
choices = choices()
if isinstance(choices, dict):
return [
(
(key, normalize_choices(value))
if isinstance(value, dict)
else (key, value)
)
for key, value in choices.items()
]
return choices
161 changes: 161 additions & 0 deletions tests/test_django5_overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import pytest
from django.db import models
from django_enum import EnumField
from django.db.models import IntegerChoices
from django import VERSION as django_version
Copy link

Copilot AI Mar 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

django_version is imported but never used in this test module. Either remove the import, or use it (e.g., for a module-level skip/conditional assertions) to avoid dead code.

Suggested change
from django import VERSION as django_version

Copilot uses AI. Check for mistakes.


class MyEnum(IntegerChoices):
VAL1 = 1, "Value 1"
VAL2 = 2, "Value 2"


class GroupedEnum(IntegerChoices):
V1 = 1, "One"
V2 = 2, "Two"
V3 = 3, "Three"


def get_choices_callable():
return [(1, "Callable 1"), (2, "Callable 2")]


class EnumOverrideModel(models.Model):
# Dict choices (Django 5.0+)
dict_field = EnumField(MyEnum, choices={1: "Dict 1", 2: "Dict 2"})

# Callable choices (Django 5.0+)
callable_field = EnumField(MyEnum, choices=get_choices_callable)

# Grouped choices
grouped_field = EnumField(
GroupedEnum,
choices=[
("Group A", [(1, "One"), (2, "Two")]),
("Group B", [(3, "Three")]),
],
)

# Nested dict choices (Django 5.0+)
nested_dict_field = EnumField(
MyEnum,
choices={
"Audio": {
1: "Vinyl",
2: "CD",
},
"unknown": "Unknown",
},
)

class Meta:
abstract = True
app_label = "tests"


def test_deconstruct_preserves_overrides():
"""Verify that deconstruct() preserves dictionary and callable overrides."""
field = EnumOverrideModel._meta.get_field("dict_field")
name, path, args, kwargs = field.deconstruct()
assert kwargs["choices"] == {1: "Dict 1", 2: "Dict 2"}

field = EnumOverrideModel._meta.get_field("callable_field")
name, path, args, kwargs = field.deconstruct()
assert kwargs["choices"] == get_choices_callable

field = EnumOverrideModel._meta.get_field("nested_dict_field")
name, path, args, kwargs = field.deconstruct()
assert kwargs["choices"] == {
"Audio": {1: "Vinyl", 2: "CD"},
"unknown": "Unknown",
}


def test_get_choices_handles_dict():
"""Verify that get_choices() handles dictionary choices correctly."""
field = EnumOverrideModel._meta.get_field("dict_field")
# Django converts dict to list of tuples internally in super().get_choices()
choices = field.get_choices(include_blank=False)
assert choices == [(1, "Dict 1"), (2, "Dict 2")]


def test_get_choices_handles_nested_dict():
"""Verify that get_choices() handles nested dictionary choices correctly."""
field = EnumOverrideModel._meta.get_field("nested_dict_field")
choices = field.get_choices(include_blank=False)
# Normalized by Django + Coerced by EnumField
assert choices == [
("Audio", [(1, "Vinyl"), (2, "CD")]),
("unknown", "Unknown"),
]


def test_get_choices_handles_recursion():
"""Verify that get_choices() handles grouped choices correctly."""
field = EnumOverrideModel._meta.get_field("grouped_field")
choices = field.get_choices(include_blank=False)
assert choices == [
("Group A", [(1, "One"), (2, "Two")]),
("Group B", [(3, "Three")]),
]


def test_validation_with_overridden_choices():
"""Verify that validation uses the overridden labels or values."""
field = EnumOverrideModel._meta.get_field("dict_field")
# Validation in Django usually checks if value is in choices
# EnumField also tries to coerce.
# 1 is a valid value for MyEnum and also in the dict
field.validate(1, None)

# 3 is NOT in MyEnum and NOT in the dict
with pytest.raises(Exception): # Django raises ValidationError
field.validate(3, None)
Comment on lines +112 to +113
Copy link

Copilot AI Mar 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test expects a ValidationError, but it currently uses pytest.raises(Exception), which can mask unrelated errors and make failures harder to diagnose. Prefer asserting the specific django.core.exceptions.ValidationError (as other tests in this suite do).

Copilot uses AI. Check for mistakes.


from django import forms as django_forms
from django_enum.forms import EnumChoiceField


def test_form_field_with_nested_dict():
"""Verify that EnumChoiceField handles nested dictionary choices (non-strict)."""

class NonStrictForm(django_forms.Form):
field = EnumChoiceField(
MyEnum,
choices={
"Audio": {1: "Vinyl", 2: "CD"},
"unknown": "Unknown",
},
strict=False,
)

form = NonStrictForm(data={"field": 1})
assert form.is_valid()

# 3 is not in Enum and not in dict
form = NonStrictForm(data={"field": 3})
assert form.is_valid() # Non-strict allows it

# Rendering should include the added choice if it's not in choices
# This verifies NonStrictMixin.render()
widget = form.fields["field"].widget
# Simulate a value not in choices
rendered = widget.render("field", 3, attrs={"id": "id_field"})
assert 'value="3"' in rendered
assert ">3<" in rendered


def test_default_choices_still_work():
"""Verify that if no choices are provided, defaults from enum are used."""

class DefaultModel(models.Model):
field = EnumField(MyEnum)

class Meta:
abstract = True
app_label = "tests"

field = DefaultModel._meta.get_field("field")
name, path, args, kwargs = field.deconstruct()
assert kwargs["choices"] == [(1, "Value 1"), (2, "Value 2")]
Loading