diff --git a/src/django_enum/fields.py b/src/django_enum/fields.py index 5e86d95..adba03f 100644 --- a/src/django_enum/fields.py +++ b/src/django_enum/fields.py @@ -464,8 +464,15 @@ def __init__( self._strict_ = strict if enum else False self._coerce_ = coerce if enum else False self._constrained_ = constrained if constrained is not None else strict + self._choices_override_ = kwargs.get("choices") if self.enum is not None: kwargs.setdefault("choices", choices(enum)) + + if django_version < (5, 0) and "choices" in kwargs: + from django_enum.utils import normalize_choices + + kwargs["choices"] = normalize_choices(kwargs["choices"]) + super().__init__( null=kwargs.pop("null", False) or None in values(self.enum), **kwargs ) @@ -556,7 +563,13 @@ def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]: """ name, path, args, kwargs = super().deconstruct() if self.enum is not None: - kwargs["choices"] = choices(self.enum) + # if choices are overridden, super().deconstruct() might have + # normalized them. We want to preserve the original format if + # possible, but only if it's not the default enum choices. + if self._choices_override_ is not None: + kwargs["choices"] = self._choices_override_ + elif "choices" not in kwargs: + kwargs["choices"] = choices(self.enum) if "db_default" in kwargs: try: @@ -726,15 +739,30 @@ def get_choices( limit_choices_to=None, ordering=(), ): - return [ - (getattr(choice, "value", choice), label) - for choice, label in super().get_choices( + def _coerce(choice): + return getattr(choice, "value", choice) + + def _recursive_coerce(choices_list): + coerced = [] + for item in choices_list: + if isinstance(item, (list, tuple)) and len(item) == 2: + choice, label = item + if isinstance(label, (list, tuple)): + coerced.append((_coerce(choice), _recursive_coerce(label))) + else: + coerced.append((_coerce(choice), label)) + else: + coerced.append((_coerce(item), item)) + return coerced + + return _recursive_coerce( + super().get_choices( include_blank=include_blank, blank_choice=list(blank_choice), limit_choices_to=limit_choices_to, ordering=ordering, ) - ] + ) @staticmethod def constraint_name( diff --git a/src/django_enum/forms.py b/src/django_enum/forms.py index 94b54c7..a855ad8 100644 --- a/src/django_enum/forms.py +++ b/src/django_enum/forms.py @@ -83,18 +83,41 @@ class NonStrictMixin: choices: _SelectChoices - def render(self, *args, **kwargs): + def render(self, name, value, attrs=None, renderer=None): """ Before rendering if we're a non-strict field and our value is not one of our choices, we add it as an option. """ - value: t.Any = getattr(kwargs.get("value"), "value", kwargs.get("value")) - if value not in EnumChoiceField.empty_values and value not in ( - choice[0] for choice in self.choices - ): - self.choices = list(self.choices) + [(value, str(value))] - return super().render(*args, **kwargs) # type: ignore[misc] + val: t.Any = getattr(value, "value", value) + + def _has_value(choices_to_search): + if isinstance(choices_to_search, dict): + for k, v in choices_to_search.items(): + if isinstance(v, (dict, list, tuple)): + if _has_value(v): + return True + elif k == val: + return True + return False + for item in choices_to_search: + if isinstance(item, (list, tuple)) and len(item) == 2: + choice, label = item + if isinstance(label, (list, tuple)): + if _has_value(label): + return True + elif choice == val: + return True + elif item == val: + return True + return False + + if val not in EnumChoiceField.empty_values and not _has_value(self.choices): + if isinstance(self.choices, dict): + self.choices = list(self.choices.items()) + [(val, str(val))] + else: + self.choices = list(self.choices) + [(val, str(val))] + return super().render(name, value, attrs=attrs, renderer=renderer) # type: ignore[misc] class NonStrictFlagMixin: @@ -106,22 +129,38 @@ class NonStrictFlagMixin: choices: _SelectChoices - def render(self, *args, **kwargs): + def render(self, name, value, attrs=None, renderer=None): """ Before rendering if we're a non-strict flag field and bits are set that are not part of our flag enumeration we add them as (integer value, bit index) to our (value, label) choice list. """ - raw_choices = zip( - get_set_values(kwargs.get("value")), get_set_bits(kwargs.get("value")) - ) + raw_choices = zip(get_set_values(value), get_set_bits(value)) self.choices = list(self.choices) - choice_values = set(choice[0] for choice in self.choices) - for value, label in raw_choices: - if value not in choice_values: - self.choices.append((value, label)) - return super().render(*args, **kwargs) # type: ignore[misc] + + def _get_values(choices_to_search): + if isinstance(choices_to_search, dict): + for k, v in choices_to_search.items(): + if isinstance(v, (dict, list, tuple)): + yield from _get_values(v) + else: + yield k + return + for choice, label in choices_to_search: + if isinstance(label, (list, tuple)): + yield from _get_values(label) + else: + yield choice + + choice_values = set(_get_values(self.choices)) + for v, label in raw_choices: + if v not in choice_values: + if isinstance(self.choices, dict): + self.choices = list(self.choices.items()) + [(v, label)] + else: + self.choices.append((v, label)) + return super().render(name, value, attrs=attrs, renderer=renderer) # type: ignore[misc] class NonStrictSelect(NonStrictMixin, Select): @@ -353,6 +392,11 @@ def __init__( self.empty_values = empty_values self._empty_values_overridden_ = True + from django_enum.utils import django_version, normalize_choices + + if django_version < (5, 0) and choices: + choices = normalize_choices(choices) + super().__init__( choices=choices or getattr(self.enum, "choices", choices), coerce=coerce or self.default_coerce, diff --git a/src/django_enum/utils.py b/src/django_enum/utils.py index 485b7d2..c9b29c1 100644 --- a/src/django_enum/utils.py +++ b/src/django_enum/utils.py @@ -13,6 +13,8 @@ get_args, ) +from django import VERSION as django_version + __all__ = [ "choices", "names", @@ -26,6 +28,8 @@ "get_set_bits", "decompose", "members", + "normalize_choices", + "django_version", ] @@ -323,3 +327,22 @@ def members(enum: type[E], aliases: bool = True) -> Generator[E, None, None]: else: for name in enum._member_names_: yield enum[name] # type: ignore[misc] + + +def normalize_choices(choices): + """ + Standardize choices for Django < 5.0. + In Django 5.0, choices can be a dict or a callable. + """ + if callable(choices): + choices = choices() + if isinstance(choices, dict): + return [ + ( + (key, normalize_choices(value)) + if isinstance(value, dict) + else (key, value) + ) + for key, value in choices.items() + ] + return choices diff --git a/tests/test_django5_overrides.py b/tests/test_django5_overrides.py new file mode 100644 index 0000000..9cbf1e3 --- /dev/null +++ b/tests/test_django5_overrides.py @@ -0,0 +1,161 @@ +import pytest +from django.db import models +from django_enum import EnumField +from django.db.models import IntegerChoices +from django import VERSION as django_version + + +class MyEnum(IntegerChoices): + VAL1 = 1, "Value 1" + VAL2 = 2, "Value 2" + + +class GroupedEnum(IntegerChoices): + V1 = 1, "One" + V2 = 2, "Two" + V3 = 3, "Three" + + +def get_choices_callable(): + return [(1, "Callable 1"), (2, "Callable 2")] + + +class EnumOverrideModel(models.Model): + # Dict choices (Django 5.0+) + dict_field = EnumField(MyEnum, choices={1: "Dict 1", 2: "Dict 2"}) + + # Callable choices (Django 5.0+) + callable_field = EnumField(MyEnum, choices=get_choices_callable) + + # Grouped choices + grouped_field = EnumField( + GroupedEnum, + choices=[ + ("Group A", [(1, "One"), (2, "Two")]), + ("Group B", [(3, "Three")]), + ], + ) + + # Nested dict choices (Django 5.0+) + nested_dict_field = EnumField( + MyEnum, + choices={ + "Audio": { + 1: "Vinyl", + 2: "CD", + }, + "unknown": "Unknown", + }, + ) + + class Meta: + abstract = True + app_label = "tests" + + +def test_deconstruct_preserves_overrides(): + """Verify that deconstruct() preserves dictionary and callable overrides.""" + field = EnumOverrideModel._meta.get_field("dict_field") + name, path, args, kwargs = field.deconstruct() + assert kwargs["choices"] == {1: "Dict 1", 2: "Dict 2"} + + field = EnumOverrideModel._meta.get_field("callable_field") + name, path, args, kwargs = field.deconstruct() + assert kwargs["choices"] == get_choices_callable + + field = EnumOverrideModel._meta.get_field("nested_dict_field") + name, path, args, kwargs = field.deconstruct() + assert kwargs["choices"] == { + "Audio": {1: "Vinyl", 2: "CD"}, + "unknown": "Unknown", + } + + +def test_get_choices_handles_dict(): + """Verify that get_choices() handles dictionary choices correctly.""" + field = EnumOverrideModel._meta.get_field("dict_field") + # Django converts dict to list of tuples internally in super().get_choices() + choices = field.get_choices(include_blank=False) + assert choices == [(1, "Dict 1"), (2, "Dict 2")] + + +def test_get_choices_handles_nested_dict(): + """Verify that get_choices() handles nested dictionary choices correctly.""" + field = EnumOverrideModel._meta.get_field("nested_dict_field") + choices = field.get_choices(include_blank=False) + # Normalized by Django + Coerced by EnumField + assert choices == [ + ("Audio", [(1, "Vinyl"), (2, "CD")]), + ("unknown", "Unknown"), + ] + + +def test_get_choices_handles_recursion(): + """Verify that get_choices() handles grouped choices correctly.""" + field = EnumOverrideModel._meta.get_field("grouped_field") + choices = field.get_choices(include_blank=False) + assert choices == [ + ("Group A", [(1, "One"), (2, "Two")]), + ("Group B", [(3, "Three")]), + ] + + +def test_validation_with_overridden_choices(): + """Verify that validation uses the overridden labels or values.""" + field = EnumOverrideModel._meta.get_field("dict_field") + # Validation in Django usually checks if value is in choices + # EnumField also tries to coerce. + # 1 is a valid value for MyEnum and also in the dict + field.validate(1, None) + + # 3 is NOT in MyEnum and NOT in the dict + with pytest.raises(Exception): # Django raises ValidationError + field.validate(3, None) + + +from django import forms as django_forms +from django_enum.forms import EnumChoiceField + + +def test_form_field_with_nested_dict(): + """Verify that EnumChoiceField handles nested dictionary choices (non-strict).""" + + class NonStrictForm(django_forms.Form): + field = EnumChoiceField( + MyEnum, + choices={ + "Audio": {1: "Vinyl", 2: "CD"}, + "unknown": "Unknown", + }, + strict=False, + ) + + form = NonStrictForm(data={"field": 1}) + assert form.is_valid() + + # 3 is not in Enum and not in dict + form = NonStrictForm(data={"field": 3}) + assert form.is_valid() # Non-strict allows it + + # Rendering should include the added choice if it's not in choices + # This verifies NonStrictMixin.render() + widget = form.fields["field"].widget + # Simulate a value not in choices + rendered = widget.render("field", 3, attrs={"id": "id_field"}) + assert 'value="3"' in rendered + assert ">3<" in rendered + + +def test_default_choices_still_work(): + """Verify that if no choices are provided, defaults from enum are used.""" + + class DefaultModel(models.Model): + field = EnumField(MyEnum) + + class Meta: + abstract = True + app_label = "tests" + + field = DefaultModel._meta.get_field("field") + name, path, args, kwargs = field.deconstruct() + assert kwargs["choices"] == [(1, "Value 1"), (2, "Value 2")]