diff --git a/netbox_custom_objects/field_types.py b/netbox_custom_objects/field_types.py index 75501f92..ebad4e5a 100644 --- a/netbox_custom_objects/field_types.py +++ b/netbox_custom_objects/field_types.py @@ -15,16 +15,22 @@ from extras.choices import CustomFieldTypeChoices, CustomFieldUIEditableChoices from utilities.api import get_serializer_for_model from utilities.forms.fields import ( - CSVChoiceField, CSVModelChoiceField, - CSVModelMultipleChoiceField, CSVMultipleChoiceField, - DynamicChoiceField, DynamicModelChoiceField, + CSVChoiceField, + CSVModelChoiceField, + CSVModelMultipleChoiceField, + CSVMultipleChoiceField, + DynamicChoiceField, + DynamicModelChoiceField, DynamicModelMultipleChoiceField, - DynamicMultipleChoiceField, JSONField, + DynamicMultipleChoiceField, + JSONField, LaxURLField, ) from utilities.forms.utils import add_blank_choice from utilities.forms.widgets import ( - APISelect, APISelectMultiple, DatePicker, + APISelect, + APISelectMultiple, + DatePicker, DateTimePicker, ) from utilities.templatetags.builtins.filters import linkify, render_markdown @@ -51,12 +57,13 @@ def __init__(self, to_model_name, *args, **kwargs): def contribute_to_class(self, cls, name, **kwargs): super().contribute_to_class(cls, name, **kwargs) # Mark this field for later resolution - setattr(cls, f"_resolve_{name}_model", self._resolve_model) + setattr(cls, f'_resolve_{name}_model', self._resolve_model) def _resolve_model(self, model): """Resolve the lazy reference to the actual model class.""" # Get the actual model class from the app registry from django.apps import apps + actual_model = apps.get_model(self._to_model_name) # Update the field's references self.remote_field.model = actual_model @@ -64,7 +71,6 @@ def _resolve_model(self, model): class FieldType: - def get_display_value(self, instance, field_name): """ This value is used as the object title in the Custom Object detail view. @@ -88,8 +94,7 @@ def _safe_kwargs(self, **kwargs): Create a safe kwargs dict that can be passed to Django field constructors. This method automatically filters out any custom parameters. """ - return {k: v for k, v in kwargs.items() - if not k.startswith('_') and k != 'generating_models'} + return {k: v for k, v in kwargs.items() if not k.startswith('_') and k != 'generating_models'} def get_annotated_form_field(self, field, enforce_visibility=True, **kwargs): form_field = self.get_form_field(field, **kwargs) @@ -119,10 +124,9 @@ def create_m2m_table(self, instance, model, field_name): ... class TextFieldType(FieldType): - def get_model_field(self, field, **kwargs): field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) + field_kwargs.update({'default': field.default, 'unique': field.unique}) return models.CharField(null=True, blank=True, **field_kwargs) def get_form_field(self, field, **kwargs): @@ -138,9 +142,7 @@ def get_form_field(self, field, **kwargs): ), ) ] - return forms.CharField( - required=field.required, initial=field.default, validators=validators - ) + return forms.CharField(required=field.required, initial=field.default, validators=validators) def get_filterform_field(self, field, **kwargs): return forms.CharField( @@ -153,7 +155,7 @@ def get_filterform_field(self, field, **kwargs): class LongTextFieldType(FieldType): def get_model_field(self, field, **kwargs): field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) + field_kwargs.update({'default': field.default, 'unique': field.unique}) return models.TextField(null=True, blank=True, **field_kwargs) def get_form_field(self, field, **kwargs): @@ -182,11 +184,10 @@ def render_table_column(self, value): class IntegerFieldType(FieldType): - def get_model_field(self, field, **kwargs): # TODO: handle all args for IntegerField field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) + field_kwargs.update({'default': field.default, 'unique': field.unique}) return models.IntegerField(null=True, blank=True, **field_kwargs) def get_filterform_field(self, field, **kwargs): @@ -207,14 +208,8 @@ def get_form_field(self, field, **kwargs): class DecimalFieldType(FieldType): def get_model_field(self, field, **kwargs): field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) - return models.DecimalField( - null=True, - blank=True, - max_digits=8, - decimal_places=2, - **field_kwargs - ) + field_kwargs.update({'default': field.default, 'unique': field.unique}) + return models.DecimalField(null=True, blank=True, max_digits=8, decimal_places=4, **field_kwargs) def get_form_field(self, field, **kwargs): return forms.DecimalField( @@ -226,11 +221,21 @@ def get_form_field(self, field, **kwargs): max_value=field.validation_maximum, ) + def get_filterform_field(self, field, **kwargs): + return forms.DecimalField( + label=field, + required=False, + max_digits=12, + decimal_places=4, + min_value=field.validation_minimum, + max_value=field.validation_maximum, + ) + class BooleanFieldType(FieldType): def get_model_field(self, field, **kwargs): field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) + field_kwargs.update({'default': field.default, 'unique': field.unique}) return models.BooleanField(null=True, blank=True, **field_kwargs) def get_form_field(self, field, **kwargs): @@ -252,43 +257,37 @@ def get_table_column_field(self, field, **kwargs): class DateFieldType(FieldType): def get_model_field(self, field, **kwargs): field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) + field_kwargs.update({'default': field.default, 'unique': field.unique}) return models.DateField(null=True, blank=True, **field_kwargs) def get_form_field(self, field, **kwargs): - return forms.DateField( - required=field.required, initial=field.default, widget=DatePicker() - ) + return forms.DateField(required=field.required, initial=field.default, widget=DatePicker()) class DateTimeFieldType(FieldType): def get_model_field(self, field, **kwargs): field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) + field_kwargs.update({'default': field.default, 'unique': field.unique}) return models.DateTimeField(null=True, blank=True, **field_kwargs) def get_form_field(self, field, **kwargs): - return forms.DateTimeField( - required=field.required, initial=field.default, widget=DateTimePicker() - ) + return forms.DateTimeField(required=field.required, initial=field.default, widget=DateTimePicker()) class URLFieldType(FieldType): def get_model_field(self, field, **kwargs): field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) + field_kwargs.update({'default': field.default, 'unique': field.unique}) return models.URLField(null=True, blank=True, **field_kwargs) def get_form_field(self, field, **kwargs): - return LaxURLField( - assume_scheme="https", required=field.required, initial=field.default - ) + return LaxURLField(assume_scheme='https', required=field.required, initial=field.default) class JSONFieldType(FieldType): def get_model_field(self, field, **kwargs): field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) + field_kwargs.update({'default': field.default, 'unique': field.unique}) return models.JSONField(null=True, blank=True, **field_kwargs) def get_form_field(self, field, **kwargs): @@ -301,14 +300,8 @@ def get_form_field(self, field, **kwargs): class SelectFieldType(FieldType): def get_model_field(self, field, **kwargs): field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) - return models.CharField( - max_length=100, - choices=field.choices, - null=True, - blank=True, - **field_kwargs - ) + field_kwargs.update({'default': field.default, 'unique': field.unique}) + return models.CharField(max_length=100, choices=field.choices, null=True, blank=True, **field_kwargs) def get_form_field(self, field, for_csv_import=False, **kwargs): choices = field.choice_set.choices @@ -324,9 +317,7 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): if for_csv_import: field_class = CSVChoiceField - return field_class( - choices=choices, required=field.required, initial=initial - ) + return field_class(choices=choices, required=field.required, initial=initial) else: field_class = DynamicChoiceField widget_class = APISelect @@ -334,9 +325,7 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): choices=choices, required=field.required, initial=initial, - widget=widget_class( - api_url=f"/api/extras/custom-field-choice-sets/{field.choice_set.pk}/choices/" - ), + widget=widget_class(api_url=f'/api/extras/custom-field-choice-sets/{field.choice_set.pk}/choices/'), ) @@ -346,12 +335,9 @@ def get_display_value(self, instance, field_name): def get_model_field(self, field, **kwargs): field_kwargs = self._safe_kwargs(**kwargs) - field_kwargs.update({"default": field.default, "unique": field.unique}) + field_kwargs.update({'default': field.default, 'unique': field.unique}) return ArrayField( - base_field=models.CharField(max_length=50, choices=field.choices), - null=True, - blank=True, - **field_kwargs + base_field=models.CharField(max_length=50, choices=field.choices), null=True, blank=True, **field_kwargs ) def get_form_field(self, field, for_csv_import=False, **kwargs): @@ -368,9 +354,7 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): if for_csv_import: field_class = CSVMultipleChoiceField - return field_class( - choices=choices, required=field.required, initial=initial - ) + return field_class(choices=choices, required=field.required, initial=initial) else: field_class = DynamicMultipleChoiceField widget_class = APISelectMultiple @@ -378,9 +362,7 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): choices=choices, required=field.required, initial=initial, - widget=widget_class( - api_url=f"/api/extras/custom-field-choice-sets/{field.choice_set.pk}/choices/" - ), + widget=widget_class(api_url=f'/api/extras/custom-field-choice-sets/{field.choice_set.pk}/choices/'), ) # TODO: Implement this @@ -393,38 +375,61 @@ def render_table_column(self, value): return ", ".join(value) -class ObjectFieldType(FieldType): +class RelatedObjectFilterFormMixin: + """Mixin providing shared get_filterform_field logic for object reference field types.""" + + _filterform_field_class = None + + def get_filterform_field(self, field, **kwargs): + content_type = ContentType.objects.get(pk=field.related_object_type_id) + if content_type.app_label == APP_LABEL: + from netbox_custom_objects.models import CustomObjectType + + custom_object_type_id = content_type.model.replace('table', '').replace('model', '') + custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) + model = custom_object_type.get_model() + else: + model = content_type.model_class() + return self._filterform_field_class( + queryset=model.objects.all(), + required=False, + label=field, + selector=model._meta.app_label != APP_LABEL, + ) + + +class ObjectFieldType(RelatedObjectFilterFormMixin, FieldType): + _filterform_field_class = DynamicModelChoiceField + def get_model_field(self, field, **kwargs): content_type = ContentType.objects.get(pk=field.related_object_type_id) to_model = content_type.model # Extract our custom parameters and keep only Django field parameters field_kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_')} - field_kwargs.update({"default": field.default, "unique": field.unique}) + field_kwargs.update({'default': field.default, 'unique': field.unique}) # Handle self-referential fields by using string references if content_type.app_label == APP_LABEL: from netbox_custom_objects.models import CustomObjectType - custom_object_type_id = content_type.model.replace("table", "").replace( - "model", "" - ) + custom_object_type_id = content_type.model.replace('table', '').replace('model', '') custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) # Check if this is a self-referential field if custom_object_type.id == field.custom_object_type.id: # For self-referential fields, use LazyForeignKey to defer resolution - model_name = f"{APP_LABEL}.{custom_object_type.get_table_model_name(custom_object_type.id)}" + model_name = f'{APP_LABEL}.{custom_object_type.get_table_model_name(custom_object_type.id)}' # Generate a unique related_name to prevent reverse accessor conflicts table_model_name = field.custom_object_type.get_table_model_name(field.custom_object_type.id).lower() - related_name = f"{table_model_name}_{field.name}_set" + related_name = f'{table_model_name}_{field.name}_set' f = LazyForeignKey( model_name, null=True, blank=True, on_delete=models.CASCADE, related_name=related_name, - **field_kwargs + **field_kwargs, ) return f else: @@ -437,7 +442,7 @@ def get_model_field(self, field, **kwargs): # Generate a unique related_name to prevent reverse accessor conflicts table_model_name = field.custom_object_type.get_table_model_name(field.custom_object_type.id).lower() - related_name = f"{table_model_name}_{field.name}_set" + related_name = f'{table_model_name}_{field.name}_set' f = models.ForeignKey( model, null=True, blank=True, on_delete=models.CASCADE, related_name=related_name, **field_kwargs ) @@ -456,9 +461,7 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): # This is a custom object type from netbox_custom_objects.models import CustomObjectType - custom_object_type_id = content_type.model.replace("table", "").replace( - "model", "" - ) + custom_object_type_id = content_type.model.replace('table', '').replace('model', '') custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) model = custom_object_type.get_model() @@ -482,30 +485,10 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): queryset=model.objects.all(), required=field.required, # Remove initial=field.default to allow Django to handle instance data properly - query_params=( - field.related_object_filter - if hasattr(field, "related_object_filter") - else None - ), + query_params=(field.related_object_filter if hasattr(field, 'related_object_filter') else None), selector=model._meta.app_label != APP_LABEL, ) - def get_filterform_field(self, field, **kwargs): - content_type = ContentType.objects.get(pk=field.related_object_type_id) - if content_type.app_label == APP_LABEL: - from netbox_custom_objects.models import CustomObjectType - custom_object_type_id = content_type.model.replace("table", "").replace("model", "") - custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) - model = custom_object_type.get_model() - else: - model = content_type.model_class() - return DynamicModelChoiceField( - queryset=model.objects.all(), - required=False, - label=field, - selector=model._meta.app_label != APP_LABEL, - ) - def render_table_column(self, value): return linkify(value) @@ -513,6 +496,7 @@ def get_serializer_field(self, field, **kwargs): related_model_class = field.related_object_type.model_class() if related_model_class._meta.app_label == APP_LABEL: from netbox_custom_objects.api.serializers import get_serializer_class + serializer = get_serializer_class(related_model_class, skip_object_fields=True) else: serializer = get_serializer_for_model(related_model_class) @@ -544,9 +528,9 @@ def get_prefetch_queryset(self, instances, queryset=None): queryset = self.get_queryset() # Get all the target IDs for these instances in a single query - through_queryset = self.through.objects.filter( - source_id__in=[obj.pk for obj in instances] - ).values_list("source_id", "target_id") + through_queryset = self.through.objects.filter(source_id__in=[obj.pk for obj in instances]).values_list( + 'source_id', 'target_id' + ) # Build a mapping of instance PKs to their related objects rel_obj_cache = {source_id: [] for source_id in [obj.pk for obj in instances]} @@ -562,17 +546,13 @@ def get_prefetch_queryset(self, instances, queryset=None): # Build the final cache mapping for source_id, target_ids in rel_obj_cache.items(): rel_obj_cache[source_id] = [ - target_objects[target_id] - for target_id in target_ids - if target_id in target_objects + target_objects[target_id] for target_id in target_ids if target_id in target_objects ] return ( target_queryset, # queryset containing all the related objects lambda obj: obj.pk, # function to get the related object ID - lambda obj: rel_obj_cache[ - obj.pk - ], # function to get the list of related objects + lambda obj: rel_obj_cache[obj.pk], # function to get the list of related objects False, # single related object (False for M2M) self.prefetch_cache_name, # cache name False, # is a descriptor (False for M2M) @@ -584,9 +564,7 @@ def get_queryset(self): # Join through the through table using a subquery qs = base_qs.filter( - pk__in=self.through.objects.filter(source_id=self.instance.pk).values_list( - "target_id", flat=True - ) + pk__in=self.through.objects.filter(source_id=self.instance.pk).values_list('target_id', flat=True) ) # Add default ordering by pk @@ -594,15 +572,11 @@ def get_queryset(self): def add(self, *objs): for obj in objs: - self.through.objects.get_or_create( - source_id=self.instance.pk, target_id=obj.pk - ) + self.through.objects.get_or_create(source_id=self.instance.pk, target_id=obj.pk) def remove(self, *objs): for obj in objs: - self.through.objects.filter( - source_id=self.instance.pk, target_id=obj.pk - ).delete() + self.through.objects.filter(source_id=self.instance.pk, target_id=obj.pk).delete() def clear(self): self.through.objects.filter(source_id=self.instance.pk).delete() @@ -675,7 +649,9 @@ def get_joining_columns(self, reverse_join=False): return ((self.m2m_field_name(), "id"),) -class MultiObjectFieldType(FieldType): +class MultiObjectFieldType(RelatedObjectFilterFormMixin, FieldType): + _filterform_field_class = DynamicModelMultipleChoiceField + def get_through_model(self, field, model_string): """ Creates a through model with deferred model references @@ -697,12 +673,9 @@ def get_through_model(self, field, model_string): # Check if this is a self-referential M2M content_type = ContentType.objects.get(pk=field.related_object_type_id) - custom_object_type_id = content_type.model.replace("table", "").replace( - "model", "" - ) + custom_object_type_id = content_type.model.replace('table', '').replace('model', '') is_self_referential = ( - content_type.app_label == APP_LABEL - and field.custom_object_type.id == custom_object_type_id + content_type.app_label == APP_LABEL and field.custom_object_type.id == custom_object_type_id ) attrs = { @@ -716,7 +689,7 @@ def get_through_model(self, field, model_string): db_column="source_id", ), "target": models.ForeignKey( - "self" if is_self_referential else model_string, + 'self' if is_self_referential else model_string, on_delete=models.CASCADE, related_name="+", db_column="target_id", @@ -731,35 +704,32 @@ def get_model_field(self, field, **kwargs): """ # Check if this is a self-referential M2M content_type = ContentType.objects.get(pk=field.related_object_type_id) - custom_object_type_id = content_type.model.replace("table", "").replace( - "model", "" - ) + custom_object_type_id = content_type.model.replace('table', '').replace('model', '') # Extract our custom parameters and keep only Django field parameters field_kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_')} # Remove default from field_kwargs since ManyToManyField doesn't handle defaults the same way - field_kwargs.update({"unique": field.unique}) + field_kwargs.update({'unique': field.unique}) is_self_referential = ( - content_type.app_label == APP_LABEL - and field.custom_object_type.id == custom_object_type_id + content_type.app_label == APP_LABEL and field.custom_object_type.id == custom_object_type_id ) # For now, we'll create the through model with string references # and resolve them later in after_model_generation # TODO: Check whether later resolution of the model is actually necessary or can be passed as string - model_string = f"{field.related_object_type.app_label}.{field.related_object_type.model}" + model_string = f'{field.related_object_type.app_label}.{field.related_object_type.model}' through = self.get_through_model(field, model_string) # For self-referential fields, use 'self' as the target m2m_field = CustomManyToManyField( - to="self" if is_self_referential else model_string, + to='self' if is_self_referential else model_string, through=through, through_fields=("source", "target"), blank=True, related_name="+", related_query_name="+", - **field_kwargs + **field_kwargs, ) # Store metadata for later resolution @@ -779,9 +749,7 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): # This is a custom object type from netbox_custom_objects.models import CustomObjectType - custom_object_type_id = content_type.model.replace("table", "").replace( - "model", "" - ) + custom_object_type_id = content_type.model.replace('table', '').replace('model', '') custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) model = custom_object_type.get_model(skip_object_fields=True) @@ -803,33 +771,13 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): return field_class( queryset=model.objects.all(), required=field.required, - query_params=( - field.related_object_filter - if hasattr(field, "related_object_filter") - else None - ), + query_params=(field.related_object_filter if hasattr(field, 'related_object_filter') else None), selector=model._meta.app_label != APP_LABEL, ) - def get_filterform_field(self, field, **kwargs): - content_type = ContentType.objects.get(pk=field.related_object_type_id) - if content_type.app_label == APP_LABEL: - from netbox_custom_objects.models import CustomObjectType - custom_object_type_id = content_type.model.replace("table", "").replace("model", "") - custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) - model = custom_object_type.get_model() - else: - model = content_type.model_class() - return DynamicModelMultipleChoiceField( - queryset=model.objects.all(), - required=False, - label=field, - selector=model._meta.app_label != APP_LABEL, - ) - def get_display_value(self, instance, field_name): field = getattr(instance, field_name) - return ", ".join(str(s) for s in field.all()) + return ', '.join(str(s) for s in field.all()) def get_table_column_field(self, field, **kwargs): return tables.ManyToManyColumn(linkify_item=True, orderable=False) @@ -838,6 +786,7 @@ def get_serializer_field(self, field, **kwargs): related_model_class = field.related_object_type.model_class() if related_model_class._meta.app_label == APP_LABEL: from netbox_custom_objects.api.serializers import get_serializer_class + serializer = get_serializer_class(related_model_class, skip_object_fields=True) else: serializer = get_serializer_for_model(related_model_class) @@ -855,8 +804,8 @@ def after_model_generation(self, instance, model, field_name): through_model = field.remote_field.through # Update both source and target fields to point to the same model - source_field = through_model._meta.get_field("source") - target_field = through_model._meta.get_field("target") + source_field = through_model._meta.get_field('source') + target_field = through_model._meta.get_field('target') # Resolve the foreign key fields to point to the actual model source_field.remote_field.model = model @@ -877,9 +826,7 @@ def after_model_generation(self, instance, model, field_name): if content_type.app_label == APP_LABEL: from netbox_custom_objects.models import CustomObjectType - custom_object_type_id = content_type.model.replace("table", "").replace( - "model", "" - ) + custom_object_type_id = content_type.model.replace('table', '').replace('model', '') custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) # For self-referential fields, we need to resolve them to the current model @@ -898,7 +845,7 @@ def after_model_generation(self, instance, model, field_name): # Update through model's target field through_model = field.remote_field.through - source_field = through_model._meta.get_field("source") + source_field = through_model._meta.get_field('source') target_field = through_model._meta.get_field("target") # Source field should point to the current model @@ -926,12 +873,8 @@ def create_m2m_table(self, instance, model, field_name): if content_type.app_label == APP_LABEL: from netbox_custom_objects.models import CustomObjectType - custom_object_type_id = content_type.model.replace("table", "").replace( - "model", "" - ) - custom_object_type = CustomObjectType.objects.get( - pk=custom_object_type_id - ) + custom_object_type_id = content_type.model.replace('table', '').replace('model', '') + custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) to_model = custom_object_type.get_model() else: @@ -941,8 +884,8 @@ def create_m2m_table(self, instance, model, field_name): through = self.get_through_model(instance, model) # Update the through model's foreign key references - source_field = through._meta.get_field("source") - target_field = through._meta.get_field("target") + source_field = through._meta.get_field('source') + target_field = through._meta.get_field('target') # Source field should point to the current model source_field.remote_field.model = model diff --git a/netbox_custom_objects/filtersets.py b/netbox_custom_objects/filtersets.py index 35ea4b50..068c8462 100644 --- a/netbox_custom_objects/filtersets.py +++ b/netbox_custom_objects/filtersets.py @@ -1,9 +1,18 @@ import django_filters +from dataclasses import dataclass +from typing import Any, Dict, Optional, Type + from django.contrib.postgres.fields import ArrayField -from django.db.models import JSONField, Q +from django.db.models import JSONField, QuerySet, Q from extras.choices import CustomFieldTypeChoices from netbox.filtersets import NetBoxModelFilterSet +from utilities.filters import ( + MultiValueDateFilter, + MultiValueDateTimeFilter, + MultiValueDecimalFilter, + MultiValueNumberFilter, +) from .models import CustomObjectType @@ -13,6 +22,66 @@ ) +@dataclass +class FilterSpec: + """ + Declarative specification describing how a custom field type + should be translated into a django-filter Filter instance. + """ + + filter_class: Type[django_filters.Filter] + lookup_expr: Optional[str] = None + extra_kwargs: Optional[Dict[str, Any]] = None + + def build( + self, field_name: str, label: str, queryset: Optional[QuerySet] = None, **kwargs + ) -> django_filters.Filter: + """ + Instantiate and return a django-filter Filter. + Allows overriding defaults via **kwargs. + """ + filter_kwargs = { + 'field_name': field_name, + 'label': label, + } + + if self.lookup_expr: + filter_kwargs['lookup_expr'] = self.lookup_expr + + if queryset is not None: + filter_kwargs['queryset'] = queryset + + # Apply defaults from the spec + if self.extra_kwargs: + filter_kwargs.update(self.extra_kwargs) + + # Apply dynamic overrides (e.g. resolved choices) + filter_kwargs.update(kwargs) + + return self.filter_class(**filter_kwargs) + + +FIELD_TYPE_FILTERS = { + CustomFieldTypeChoices.TYPE_TEXT: FilterSpec(django_filters.CharFilter, lookup_expr='icontains'), + CustomFieldTypeChoices.TYPE_LONGTEXT: FilterSpec(django_filters.CharFilter, lookup_expr='icontains'), + CustomFieldTypeChoices.TYPE_INTEGER: FilterSpec(MultiValueNumberFilter, lookup_expr='exact'), + CustomFieldTypeChoices.TYPE_DECIMAL: FilterSpec(MultiValueDecimalFilter, lookup_expr='exact'), + CustomFieldTypeChoices.TYPE_BOOLEAN: FilterSpec(django_filters.BooleanFilter), + CustomFieldTypeChoices.TYPE_DATE: FilterSpec(MultiValueDateFilter, lookup_expr='exact'), + CustomFieldTypeChoices.TYPE_DATETIME: FilterSpec(MultiValueDateTimeFilter, lookup_expr='exact'), + CustomFieldTypeChoices.TYPE_URL: FilterSpec(django_filters.CharFilter, lookup_expr='icontains'), + CustomFieldTypeChoices.TYPE_JSON: FilterSpec(django_filters.CharFilter, lookup_expr='icontains'), + CustomFieldTypeChoices.TYPE_SELECT: FilterSpec( + django_filters.ChoiceFilter, extra_kwargs={'choices': lambda f: f.choices} + ), + CustomFieldTypeChoices.TYPE_MULTISELECT: FilterSpec( + django_filters.MultipleChoiceFilter, extra_kwargs={'choices': lambda f: f.choices} + ), + CustomFieldTypeChoices.TYPE_OBJECT: FilterSpec(django_filters.ModelChoiceFilter), + CustomFieldTypeChoices.TYPE_MULTIOBJECT: FilterSpec(django_filters.ModelMultipleChoiceFilter), +} + + class CustomObjectTypeFilterSet(NetBoxModelFilterSet): class Meta: model = CustomObjectType @@ -23,10 +92,40 @@ class Meta: ) +def build_filter_for_field(field) -> Optional[django_filters.Filter]: + if not (spec := FIELD_TYPE_FILTERS.get(field.type)): + return None + + queryset = None + if field.type in ( + CustomFieldTypeChoices.TYPE_OBJECT, + CustomFieldTypeChoices.TYPE_MULTIOBJECT, + ): + related_object_type = getattr(field, 'related_object_type', None) + if not related_object_type: + # Defensive guard: if data integrity is compromised and the related object type + # is missing, skip building a filter for this field rather than raising. + return None + queryset = related_object_type.model_class().objects.all() + + extra_kwargs = {} + if spec.extra_kwargs: + for key, value in spec.extra_kwargs.items(): + extra_kwargs[key] = value(field) if callable(value) else value + + return spec.build( + field_name=field.name, + label=field.label, + queryset=queryset, + **extra_kwargs, + ) + + def get_filterset_class(model): """ Create and return a filterset class for the given custom object model. """ + # Get standard fields from the model fields = [field.name for field in model._meta.fields] meta = type( @@ -70,10 +169,19 @@ def search(self, queryset, name, value): attrs = { "Meta": meta, - "__module__": "database.filtersets", + "__module__": "netbox_custom_objects.filtersets", "search": search, } + # For each custom field, add a corresponding filter. + # Multiobject (M2M) fields are handled separately below via the through-table approach. + for field in model.custom_object_type.fields.all(): + if field.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + continue + filter_instance = build_filter_for_field(field) + if filter_instance: + attrs[field.name] = filter_instance + # Add filters for M2M (multiobject) fields, which are not in model._meta.fields. # By the time get_filterset_class() is called (at request time), after_model_generation() # will have already resolved m2m_field.remote_field.model and .through to actual model @@ -92,6 +200,7 @@ def filter_m2m(self, queryset, name, value): target_id__in=ids ).values_list("source_id", flat=True) return queryset.filter(pk__in=source_ids) + filter_m2m.__name__ = f"filter_{fname}" return filter_m2m diff --git a/netbox_custom_objects/tests/test_filtersets.py b/netbox_custom_objects/tests/test_filtersets.py index e3de7016..d45c2e4a 100644 --- a/netbox_custom_objects/tests/test_filtersets.py +++ b/netbox_custom_objects/tests/test_filtersets.py @@ -1,31 +1,345 @@ +import datetime +from decimal import Decimal +from itertools import chain + +import django_filters +from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation +from django.contrib.contenttypes.models import ContentType +from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel from django.test import TestCase +try: + from taggit.managers import TaggableManager +except ImportError: + TaggableManager = None + from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site from netbox_custom_objects.field_types import MultiObjectFieldType, ObjectFieldType -from netbox_custom_objects.filtersets import get_filterset_class -from netbox_custom_objects.models import CustomObjectTypeField -from utilities.forms.fields import ( - DynamicModelChoiceField, - DynamicModelMultipleChoiceField, -) +from netbox_custom_objects.filtersets import CustomObjectTypeFilterSet, get_filterset_class +from netbox_custom_objects.models import CustomObjectType, CustomObjectTypeField +from utilities.forms.fields import DynamicModelChoiceField, DynamicModelMultipleChoiceField from .base import CustomObjectsTestCase +EXEMPT_MODEL_FIELDS = ( + 'comments', + 'custom_field_data', + 'level', + 'lft', + 'rght', + 'tree_id', +) + + def _make_device_fixtures(suffix): """Create minimal DCIM fixtures needed to instantiate a Device.""" - site = Site.objects.create(name=f"Site {suffix}", slug=f"site-{suffix}") - mfr = Manufacturer.objects.create(name=f"Mfr {suffix}", slug=f"mfr-{suffix}") - dt = DeviceType.objects.create( - manufacturer=mfr, model=f"DT {suffix}", slug=f"dt-{suffix}" - ) - role = DeviceRole.objects.create( - name=f"Role {suffix}", slug=f"role-{suffix}", color="ff0000" - ) + site = Site.objects.create(name=f'Site {suffix}', slug=f'site-{suffix}') + mfr = Manufacturer.objects.create(name=f'Mfr {suffix}', slug=f'mfr-{suffix}') + dt = DeviceType.objects.create(manufacturer=mfr, model=f'DT {suffix}', slug=f'dt-{suffix}') + role = DeviceRole.objects.create(name=f'Role {suffix}', slug=f'role-{suffix}', color='ff0000') return site, dt, role +# --------------------------------------------------------------------------- +# BaseFilterSetTests mixin +# --------------------------------------------------------------------------- + + +class BaseFilterSetTests: + """ + Mixin that asserts every model field has a corresponding filter defined on its FilterSet. + Fields intentionally not filterable should be listed in ignore_fields. + """ + + ignore_fields = () + + def _get_filters_for_field(self, field): + if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel: + if field.related_model is ContentType: + return [(None, None)] + return [(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter)] + + if type(field) in (ManyToManyField, ManyToManyRel): + if field.related_model is ContentType: + return [ + ('object_type', None), + ('object_type_id', django_filters.ModelMultipleChoiceFilter), + ] + related_name = field.related_model._meta.verbose_name.lower().replace(' ', '_') + return [(f'{related_name}_id', django_filters.ModelMultipleChoiceFilter)] + + if TaggableManager is not None and type(field) is TaggableManager: + return [('tag', None)] + + return [(field.name, None)] + + def test_missing_filters(self): + model = self.queryset.model + defined_filters = self.filterset.get_filters() + + for model_field in model._meta.get_fields(): + if model_field.name.startswith('_'): + continue + if model_field.name in chain(self.ignore_fields, EXEMPT_MODEL_FIELDS): + continue + if type(model_field) is ManyToOneRel: + continue + if type(model_field) in (GenericForeignKey, GenericRelation): + continue + + for filter_name, filter_class in self._get_filters_for_field(model_field): + if filter_name is None: + continue + self.assertIn( + filter_name, + defined_filters.keys(), + f'No filter defined for {filter_name} ({model_field.name})!', + ) + if filter_class is not None: + self.assertIsInstance( + defined_filters[filter_name], + filter_class, + f'Invalid filter class for {filter_name} (expected {filter_class})!', + ) + + +# --------------------------------------------------------------------------- +# CustomObjectTypeFilterSet (static) +# --------------------------------------------------------------------------- + + +class CustomObjectTypeFilterSetTestCase(CustomObjectsTestCase, TestCase, BaseFilterSetTests): + filterset = CustomObjectTypeFilterSet + # Fields intentionally not covered by CustomObjectTypeFilterSet + ignore_fields = ( + 'slug', + 'description', + 'verbose_name_plural', + ) + + @classmethod + def setUpTestData(cls): + CustomObjectType.objects.create(name='Type 1', slug='type-1') + CustomObjectType.objects.create(name='Type 2', slug='type-2', group_name='Group A') + CustomObjectType.objects.create(name='Type 3', slug='type-3', group_name='Group A') + + @property + def queryset(self): + return CustomObjectType.objects.all() + + def test_id(self): + params = {'id': list(CustomObjectType.objects.values_list('pk', flat=True)[:2])} + self.assertEqual(self.filterset(params, CustomObjectType.objects.all()).qs.count(), 2) + + def test_name(self): + params = {'name': ['Type 1', 'Type 2']} + self.assertEqual(self.filterset(params, CustomObjectType.objects.all()).qs.count(), 2) + + def test_group_name(self): + params = {'group_name': ['Group A']} + self.assertEqual(self.filterset(params, CustomObjectType.objects.all()).qs.count(), 2) + + def test_q(self): + params = {'q': 'Type 1'} + self.assertEqual(self.filterset(params, CustomObjectType.objects.all()).qs.count(), 1) + + +# --------------------------------------------------------------------------- +# Dynamic filterset — one field per supported type +# --------------------------------------------------------------------------- + + +class CustomObjectFilterSetTestCase(CustomObjectsTestCase, TestCase): + """ + Tests for dynamically generated filtersets on custom object instances. + Verifies that a filter for each supported field type is functional and + returns the correct results. Range filters (__lte/__gte) on date and numeric + fields are auto-generated by NetBoxModelFilterSet via get_additional_lookups(). + """ + + @classmethod + def setUpTestData(cls): + # Devices used for object/multiobject field tests + manufacturer = Manufacturer.objects.create(name='FS Manufacturer', slug='fs-manufacturer') + device_type = DeviceType.objects.create( + manufacturer=manufacturer, model='FS Device Type', slug='fs-device-type' + ) + role = DeviceRole.objects.create(name='FS Role', slug='fs-role', color='ff0000') + site = Site.objects.create(name='FS Site', slug='fs-site') + cls.device1 = Device.objects.create(name='FS Device 1', device_type=device_type, role=role, site=site) + cls.device2 = Device.objects.create(name='FS Device 2', device_type=device_type, role=role, site=site) + + choice_set = CustomObjectsTestCase.create_choice_set(name='FS Choice Set') + device_object_type = CustomObjectsTestCase.get_device_object_type() + + cls.cot = CustomObjectsTestCase.create_custom_object_type(name='FilterSetObject', slug='filterset-objects') + + for field_def in [ + {'name': 'text_field', 'label': 'Text Field', 'type': 'text'}, + {'name': 'longtext_field', 'label': 'Long Text Field', 'type': 'longtext'}, + {'name': 'int_field', 'label': 'Integer Field', 'type': 'integer'}, + {'name': 'decimal_field', 'label': 'Decimal Field', 'type': 'decimal'}, + {'name': 'bool_field', 'label': 'Boolean Field', 'type': 'boolean'}, + {'name': 'date_field', 'label': 'Date Field', 'type': 'date'}, + {'name': 'url_field', 'label': 'URL Field', 'type': 'url'}, + {'name': 'json_field', 'label': 'JSON Field', 'type': 'json'}, + ]: + CustomObjectsTestCase.create_custom_object_type_field(cls.cot, **field_def) + + CustomObjectsTestCase.create_custom_object_type_field( + cls.cot, name='select_field', label='Select Field', type='select', choice_set=choice_set + ) + CustomObjectsTestCase.create_custom_object_type_field( + cls.cot, name='device_field', label='Device Field', type='object', related_object_type=device_object_type + ) + CustomObjectsTestCase.create_custom_object_type_field( + cls.cot, + name='devices_field', + label='Devices Field', + type='multiobject', + related_object_type=device_object_type, + ) + + cls.model = cls.cot.get_model() + cls.filterset = get_filterset_class(cls.model) + + cls.obj1 = cls.model.objects.create( + text_field='Alpha value', + longtext_field='Alpha long text', + int_field=10, + decimal_field=Decimal('1.5000'), + bool_field=True, + date_field=datetime.date(2024, 1, 1), + url_field='https://alpha.example.com', + json_field={'tag': 'alpha'}, + select_field='choice1', + device_field=cls.device1, + ) + cls.obj2 = cls.model.objects.create( + text_field='Beta value', + longtext_field='Beta long text', + int_field=20, + decimal_field=Decimal('2.5000'), + bool_field=False, + date_field=datetime.date(2024, 6, 15), + url_field='https://beta.example.com', + json_field={'tag': 'beta'}, + select_field='choice2', + device_field=cls.device2, + ) + cls.obj3 = cls.model.objects.create( + text_field='Gamma value', + longtext_field='Gamma long text', + int_field=30, + decimal_field=Decimal('3.5000'), + bool_field=True, + date_field=datetime.date(2024, 12, 31), + url_field='https://gamma.example.com', + json_field={'tag': 'gamma'}, + select_field='choice1', + device_field=cls.device1, + ) + + cls.obj1.devices_field.add(cls.device1) + cls.obj2.devices_field.add(cls.device2) + cls.obj3.devices_field.add(cls.device1, cls.device2) + + @property + def queryset(self): + return self.model.objects.all() + + # --- Text types (icontains) --- + + def test_text_field(self): + params = {'text_field': 'alpha'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) + + def test_longtext_field(self): + params = {'longtext_field': 'beta'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) + + def test_url_field(self): + params = {'url_field': 'gamma'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) + + def test_json_field(self): + params = {'json_field': 'alpha'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) + + # --- Numeric types (exact + range lookups) --- + + def test_integer_field(self): + params = {'int_field': 20} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) + + def test_integer_field_lte(self): + params = {'int_field__lte': 20} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + def test_integer_field_gte(self): + params = {'int_field__gte': 20} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + def test_decimal_field(self): + params = {'decimal_field': '2.5'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) + + def test_decimal_field_lte(self): + params = {'decimal_field__lte': '2.5'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + def test_decimal_field_gte(self): + params = {'decimal_field__gte': '2.5'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + # --- Boolean --- + + def test_boolean_field_true(self): + params = {'bool_field': True} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + def test_boolean_field_false(self): + params = {'bool_field': False} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) + + # --- Date (exact + range lookups auto-generated by NetBoxModelFilterSet) --- + + def test_date_field(self): + params = {'date_field': '2024-01-01'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) + + def test_date_field_lte(self): + # obj1 (2024-01-01) and obj2 (2024-06-15) are on or before 2024-06-15 + params = {'date_field__lte': '2024-06-15'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + def test_date_field_gte(self): + # obj2 (2024-06-15) and obj3 (2024-12-31) are on or after 2024-06-15 + params = {'date_field__gte': '2024-06-15'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + # --- Choice --- + + def test_select_field(self): + # obj1 and obj3 have choice1 + params = {'select_field': 'choice1'} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + # --- Object references --- + + def test_object_field(self): + # obj1 and obj3 reference device1 + params = {'device_field': self.device1.pk} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + def test_multiobject_field(self): + # obj2 and obj3 reference device2 + params = {'devices_field': [self.device2.pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + # --------------------------------------------------------------------------- # ObjectFieldType.get_filterform_field — form field shape # --------------------------------------------------------------------------- @@ -37,15 +351,15 @@ class ObjectFieldFilterFormFieldTestCase(CustomObjectsTestCase, TestCase): @classmethod def setUpTestData(cls): super().setUpTestData() - cls.cot = cls.create_custom_object_type(name="ObjFFTest", slug="obj-ff-test") + cls.cot = cls.create_custom_object_type(name='ObjFFTest', slug='obj-ff-test') cls.create_custom_object_type_field( - cls.cot, name="name", label="Name", type="text", primary=True, required=True + cls.cot, name='name', label='Name', type='text', primary=True, required=True ) cls.field = cls.create_custom_object_type_field( cls.cot, - name="device", - label="Device", - type="object", + name='device', + label='Device', + type='object', related_object_type=cls.get_device_object_type(), ) @@ -73,43 +387,39 @@ class ObjectFieldFiltersetTestCase(CustomObjectsTestCase, TestCase): @classmethod def setUpTestData(cls): super().setUpTestData() - cls.cot = cls.create_custom_object_type(name="ObjFSTest", slug="obj-fs-test") + cls.cot = cls.create_custom_object_type(name='ObjFSTest', slug='obj-fs-test') cls.create_custom_object_type_field( - cls.cot, name="name", label="Name", type="text", primary=True, required=True + cls.cot, name='name', label='Name', type='text', primary=True, required=True ) cls.create_custom_object_type_field( cls.cot, - name="device", - label="Device", - type="object", + name='device', + label='Device', + type='object', related_object_type=cls.get_device_object_type(), ) - site, dt, role = _make_device_fixtures("ofst") - cls.device1 = Device.objects.create( - name="Device OFS 1", site=site, device_type=dt, role=role - ) - cls.device2 = Device.objects.create( - name="Device OFS 2", site=site, device_type=dt, role=role - ) + site, dt, role = _make_device_fixtures('ofst') + cls.device1 = Device.objects.create(name='Device OFS 1', site=site, device_type=dt, role=role) + cls.device2 = Device.objects.create(name='Device OFS 2', site=site, device_type=dt, role=role) model = cls.cot.get_model() - cls.obj_d1 = model.objects.create(name="Obj D1", device=cls.device1) - cls.obj_d2 = model.objects.create(name="Obj D2", device=cls.device2) - cls.obj_none = model.objects.create(name="Obj None") + cls.obj_d1 = model.objects.create(name='Obj D1', device=cls.device1) + cls.obj_d2 = model.objects.create(name='Obj D2', device=cls.device2) + cls.obj_none = model.objects.create(name='Obj None') def _filterset(self, params): model = self.cot.get_model() return get_filterset_class(model)(params, model.objects.all()) def test_filter_returns_matching_object(self): - pks = list(self._filterset({"device": self.device1.pk}).qs.values_list("pk", flat=True)) + pks = list(self._filterset({'device': self.device1.pk}).qs.values_list('pk', flat=True)) self.assertIn(self.obj_d1.pk, pks) self.assertNotIn(self.obj_d2.pk, pks) self.assertNotIn(self.obj_none.pk, pks) def test_filter_different_value(self): - pks = list(self._filterset({"device": self.device2.pk}).qs.values_list("pk", flat=True)) + pks = list(self._filterset({'device': self.device2.pk}).qs.values_list('pk', flat=True)) self.assertIn(self.obj_d2.pk, pks) self.assertNotIn(self.obj_d1.pk, pks) @@ -128,15 +438,15 @@ class MultiObjectFieldFilterFormFieldTestCase(CustomObjectsTestCase, TestCase): @classmethod def setUpTestData(cls): super().setUpTestData() - cls.cot = cls.create_custom_object_type(name="MoFFTest", slug="mo-ff-test") + cls.cot = cls.create_custom_object_type(name='MoFFTest', slug='mo-ff-test') cls.create_custom_object_type_field( - cls.cot, name="name", label="Name", type="text", primary=True, required=True + cls.cot, name='name', label='Name', type='text', primary=True, required=True ) cls.field = cls.create_custom_object_type_field( cls.cot, - name="sites", - label="Sites", - type="multiobject", + name='sites', + label='Sites', + type='multiobject', related_object_type=cls.get_site_object_type(), ) @@ -164,36 +474,36 @@ class MultiObjectFieldFiltersetTestCase(CustomObjectsTestCase, TestCase): @classmethod def setUpTestData(cls): super().setUpTestData() - cls.cot = cls.create_custom_object_type(name="MoFSTest", slug="mo-fs-test") + cls.cot = cls.create_custom_object_type(name='MoFSTest', slug='mo-fs-test') cls.create_custom_object_type_field( - cls.cot, name="name", label="Name", type="text", primary=True, required=True + cls.cot, name='name', label='Name', type='text', primary=True, required=True ) cls.create_custom_object_type_field( cls.cot, - name="sites", - label="Sites", - type="multiobject", + name='sites', + label='Sites', + type='multiobject', related_object_type=cls.get_site_object_type(), ) - cls.site1 = Site.objects.create(name="Site MOFS 1", slug="site-mofs-1") - cls.site2 = Site.objects.create(name="Site MOFS 2", slug="site-mofs-2") + cls.site1 = Site.objects.create(name='Site MOFS 1', slug='site-mofs-1') + cls.site2 = Site.objects.create(name='Site MOFS 2', slug='site-mofs-2') model = cls.cot.get_model() - cls.obj_s1 = model.objects.create(name="Obj S1") + cls.obj_s1 = model.objects.create(name='Obj S1') cls.obj_s1.sites.add(cls.site1) - cls.obj_s2 = model.objects.create(name="Obj S2") + cls.obj_s2 = model.objects.create(name='Obj S2') cls.obj_s2.sites.add(cls.site2) - cls.obj_both = model.objects.create(name="Obj Both") + cls.obj_both = model.objects.create(name='Obj Both') cls.obj_both.sites.add(cls.site1, cls.site2) - cls.obj_none = model.objects.create(name="Obj None") + cls.obj_none = model.objects.create(name='Obj None') def _filterset(self, params): model = self.cot.get_model() return get_filterset_class(model)(params, model.objects.all()) def test_filter_single_site_returns_linked_objects(self): - pks = list(self._filterset({"sites": [self.site1.pk]}).qs.values_list("pk", flat=True)) + pks = list(self._filterset({'sites': [self.site1.pk]}).qs.values_list('pk', flat=True)) self.assertIn(self.obj_s1.pk, pks) self.assertIn(self.obj_both.pk, pks) self.assertNotIn(self.obj_s2.pk, pks) @@ -202,7 +512,7 @@ def test_filter_single_site_returns_linked_objects(self): def test_filter_multiple_sites_returns_union(self): # OR semantics: any object linked to site1 or site2 pks = list( - self._filterset({"sites": [self.site1.pk, self.site2.pk]}).qs.values_list("pk", flat=True) + self._filterset({'sites': [self.site1.pk, self.site2.pk]}).qs.values_list('pk', flat=True) ) self.assertIn(self.obj_s1.pk, pks) self.assertIn(self.obj_s2.pk, pks) @@ -211,12 +521,11 @@ def test_filter_multiple_sites_returns_union(self): def test_filter_multiple_sites_no_duplicates(self): # obj_both is linked to both sites but should appear only once - qs = self._filterset({"sites": [self.site1.pk, self.site2.pk]}).qs - obj_both_count = qs.filter(pk=self.obj_both.pk).count() - self.assertEqual(obj_both_count, 1) + qs = self._filterset({'sites': [self.site1.pk, self.site2.pk]}).qs + self.assertEqual(qs.filter(pk=self.obj_both.pk).count(), 1) def test_filter_other_site(self): - pks = list(self._filterset({"sites": [self.site2.pk]}).qs.values_list("pk", flat=True)) + pks = list(self._filterset({'sites': [self.site2.pk]}).qs.values_list('pk', flat=True)) self.assertIn(self.obj_s2.pk, pks) self.assertIn(self.obj_both.pk, pks) self.assertNotIn(self.obj_s1.pk, pks) @@ -236,42 +545,34 @@ class CustomObjectTargetObjectFieldTestCase(CustomObjectsTestCase, TestCase): @classmethod def setUpTestData(cls): super().setUpTestData() - # Target COT - cls.target_cot = cls.create_custom_object_type( - name="TargetObj", slug="target-obj" - ) + cls.target_cot = cls.create_custom_object_type(name='TargetObj', slug='target-obj') cls.create_custom_object_type_field( - cls.target_cot, name="name", label="Name", type="text", primary=True, required=True + cls.target_cot, name='name', label='Name', type='text', primary=True, required=True ) - # Source COT with object field → target COT - cls.source_cot = cls.create_custom_object_type( - name="SourceObj", slug="source-obj" - ) + cls.source_cot = cls.create_custom_object_type(name='SourceObj', slug='source-obj') cls.create_custom_object_type_field( - cls.source_cot, name="name", label="Name", type="text", primary=True, required=True + cls.source_cot, name='name', label='Name', type='text', primary=True, required=True ) cls.create_custom_object_type_field( cls.source_cot, - name="related", - label="Related", - type="object", + name='related', + label='Related', + type='object', related_object_type=cls.target_cot.object_type, ) target_model = cls.target_cot.get_model() - cls.target1 = target_model.objects.create(name="Target 1") - cls.target2 = target_model.objects.create(name="Target 2") + cls.target1 = target_model.objects.create(name='Target 1') + cls.target2 = target_model.objects.create(name='Target 2') source_model = cls.source_cot.get_model() - cls.source_t1 = source_model.objects.create(name="Source T1", related=cls.target1) - cls.source_t2 = source_model.objects.create(name="Source T2", related=cls.target2) - cls.source_none = source_model.objects.create(name="Source None") + cls.source_t1 = source_model.objects.create(name='Source T1', related=cls.target1) + cls.source_t2 = source_model.objects.create(name='Source T2', related=cls.target2) + cls.source_none = source_model.objects.create(name='Source None') def _field(self): - return CustomObjectTypeField.objects.get( - custom_object_type=self.source_cot, name="related" - ) + return CustomObjectTypeField.objects.get(custom_object_type=self.source_cot, name='related') def test_filterform_field_returns_dynamic_model_choice_field(self): form_field = ObjectFieldType().get_filterform_field(self._field()) @@ -280,27 +581,20 @@ def test_filterform_field_returns_dynamic_model_choice_field(self): def test_filterform_field_queryset_points_at_target_model(self): form_field = ObjectFieldType().get_filterform_field(self._field()) target_model = self.target_cot.get_model() - self.assertEqual( - form_field.queryset.model._meta.db_table, - target_model._meta.db_table, - ) + self.assertEqual(form_field.queryset.model._meta.db_table, target_model._meta.db_table) def test_filter_by_custom_object_target(self): source_model = self.source_cot.get_model() - fs = get_filterset_class(source_model)( - {"related": self.target1.pk}, source_model.objects.all() - ) - pks = list(fs.qs.values_list("pk", flat=True)) + fs = get_filterset_class(source_model)({'related': self.target1.pk}, source_model.objects.all()) + pks = list(fs.qs.values_list('pk', flat=True)) self.assertIn(self.source_t1.pk, pks) self.assertNotIn(self.source_t2.pk, pks) self.assertNotIn(self.source_none.pk, pks) def test_filter_other_target(self): source_model = self.source_cot.get_model() - fs = get_filterset_class(source_model)( - {"related": self.target2.pk}, source_model.objects.all() - ) - pks = list(fs.qs.values_list("pk", flat=True)) + fs = get_filterset_class(source_model)({'related': self.target2.pk}, source_model.objects.all()) + pks = list(fs.qs.values_list('pk', flat=True)) self.assertIn(self.source_t2.pk, pks) self.assertNotIn(self.source_t1.pk, pks) @@ -316,46 +610,38 @@ class CustomObjectTargetMultiObjectFieldTestCase(CustomObjectsTestCase, TestCase @classmethod def setUpTestData(cls): super().setUpTestData() - # Target COT - cls.target_cot = cls.create_custom_object_type( - name="TargetMObj", slug="target-mobj" - ) + cls.target_cot = cls.create_custom_object_type(name='TargetMObj', slug='target-mobj') cls.create_custom_object_type_field( - cls.target_cot, name="name", label="Name", type="text", primary=True, required=True + cls.target_cot, name='name', label='Name', type='text', primary=True, required=True ) - # Source COT with multiobject field → target COT - cls.source_cot = cls.create_custom_object_type( - name="SourceMObj", slug="source-mobj" - ) + cls.source_cot = cls.create_custom_object_type(name='SourceMObj', slug='source-mobj') cls.create_custom_object_type_field( - cls.source_cot, name="name", label="Name", type="text", primary=True, required=True + cls.source_cot, name='name', label='Name', type='text', primary=True, required=True ) cls.create_custom_object_type_field( cls.source_cot, - name="related_items", - label="Related Items", - type="multiobject", + name='related_items', + label='Related Items', + type='multiobject', related_object_type=cls.target_cot.object_type, ) target_model = cls.target_cot.get_model() - cls.target1 = target_model.objects.create(name="Target M1") - cls.target2 = target_model.objects.create(name="Target M2") + cls.target1 = target_model.objects.create(name='Target M1') + cls.target2 = target_model.objects.create(name='Target M2') source_model = cls.source_cot.get_model() - cls.source_t1 = source_model.objects.create(name="Source MT1") + cls.source_t1 = source_model.objects.create(name='Source MT1') cls.source_t1.related_items.add(cls.target1) - cls.source_t2 = source_model.objects.create(name="Source MT2") + cls.source_t2 = source_model.objects.create(name='Source MT2') cls.source_t2.related_items.add(cls.target2) - cls.source_both = source_model.objects.create(name="Source MBoth") + cls.source_both = source_model.objects.create(name='Source MBoth') cls.source_both.related_items.add(cls.target1, cls.target2) - cls.source_none = source_model.objects.create(name="Source MNone") + cls.source_none = source_model.objects.create(name='Source MNone') def _field(self): - return CustomObjectTypeField.objects.get( - custom_object_type=self.source_cot, name="related_items" - ) + return CustomObjectTypeField.objects.get(custom_object_type=self.source_cot, name='related_items') def test_filterform_field_returns_dynamic_model_multiple_choice_field(self): form_field = MultiObjectFieldType().get_filterform_field(self._field()) @@ -364,17 +650,14 @@ def test_filterform_field_returns_dynamic_model_multiple_choice_field(self): def test_filterform_field_queryset_points_at_target_model(self): form_field = MultiObjectFieldType().get_filterform_field(self._field()) target_model = self.target_cot.get_model() - self.assertEqual( - form_field.queryset.model._meta.db_table, - target_model._meta.db_table, - ) + self.assertEqual(form_field.queryset.model._meta.db_table, target_model._meta.db_table) def test_filter_single_target(self): source_model = self.source_cot.get_model() fs = get_filterset_class(source_model)( - {"related_items": [self.target1.pk]}, source_model.objects.all() + {'related_items': [self.target1.pk]}, source_model.objects.all() ) - pks = list(fs.qs.values_list("pk", flat=True)) + pks = list(fs.qs.values_list('pk', flat=True)) self.assertIn(self.source_t1.pk, pks) self.assertIn(self.source_both.pk, pks) self.assertNotIn(self.source_t2.pk, pks) @@ -383,9 +666,9 @@ def test_filter_single_target(self): def test_filter_multiple_targets_returns_union(self): source_model = self.source_cot.get_model() fs = get_filterset_class(source_model)( - {"related_items": [self.target1.pk, self.target2.pk]}, source_model.objects.all() + {'related_items': [self.target1.pk, self.target2.pk]}, source_model.objects.all() ) - pks = list(fs.qs.values_list("pk", flat=True)) + pks = list(fs.qs.values_list('pk', flat=True)) self.assertIn(self.source_t1.pk, pks) self.assertIn(self.source_t2.pk, pks) self.assertIn(self.source_both.pk, pks) @@ -394,6 +677,6 @@ def test_filter_multiple_targets_returns_union(self): def test_filter_multiple_targets_no_duplicates(self): source_model = self.source_cot.get_model() qs = get_filterset_class(source_model)( - {"related_items": [self.target1.pk, self.target2.pk]}, source_model.objects.all() + {'related_items': [self.target1.pk, self.target2.pk]}, source_model.objects.all() ).qs self.assertEqual(qs.filter(pk=self.source_both.pk).count(), 1)