diff --git a/netbox_custom_objects/api/serializers.py b/netbox_custom_objects/api/serializers.py index 48a5ac10..1c73cba3 100644 --- a/netbox_custom_objects/api/serializers.py +++ b/netbox_custom_objects/api/serializers.py @@ -4,6 +4,8 @@ from core.models import ObjectType from django.contrib.contenttypes.models import ContentType +from django.urls import NoReverseMatch +from django.utils.translation import gettext_lazy as _ from extras.choices import CustomFieldTypeChoices from netbox.api.serializers import NetBoxModelSerializer from rest_framework import serializers @@ -39,12 +41,93 @@ class Meta: ) +class PolymorphicObjectSerializerField(serializers.Field): + """ + Serializer field for polymorphic GenericForeignKey Object fields. + On read: returns a nested object representation with _content_type annotation. + On write: accepts {"content_type_id": X, "object_id": Y} or + {"app_label": "...", "model": "...", "object_id": Y}. + ``"id"`` is accepted as an alias for ``"object_id"`` so that the + dict emitted by ``to_representation`` (which uses ``"id"``) can be + round-tripped directly as write input. When both keys are present + ``"object_id"`` takes precedence; ``"id"`` is ignored. + For many=True (MultiObject polymorphic), wrap in a ListSerializer automatically. + + Pass ``allowed_content_type_ids`` (a set of ContentType PKs) to restrict which + object types may be submitted. Unrecognised types are rejected with HTTP 400. + """ + + def __init__(self, allowed_content_type_ids=None, **kwargs): + self.allowed_content_type_ids = allowed_content_type_ids + super().__init__(**kwargs) + + def to_representation(self, value): + if value is None: + return None + ct = ContentType.objects.get_for_model(value) + return { + "_content_type": f"{ct.app_label}.{ct.model}", + "id": value.pk, + "display": str(value), + } + + def to_internal_value(self, data): + if not isinstance(data, dict): + raise serializers.ValidationError(_("Expected a dict with object reference.")) + + # Resolve ContentType + try: + if "content_type_id" in data: + ct = ContentType.objects.get(pk=data["content_type_id"]) + elif "app_label" in data and "model" in data: + ct = ContentType.objects.get(app_label=data["app_label"], model=data["model"]) + else: + raise serializers.ValidationError( + _("Must provide content_type_id or (app_label + model).") + ) + except ContentType.DoesNotExist: + raise serializers.ValidationError(_("Invalid content type.")) from None + + if ( + self.allowed_content_type_ids is not None + and ct.id not in self.allowed_content_type_ids + ): + raise serializers.ValidationError( + _("Object type '%(app_label)s.%(model)s' is not allowed for this field.") + % {"app_label": ct.app_label, "model": ct.model} + ) + + model_class = ct.model_class() + if model_class is None: + raise serializers.ValidationError(_("Cannot resolve the specified object type.")) + + obj_id = data.get("object_id") if "object_id" in data else data.get("id") + if obj_id is None: + raise serializers.ValidationError(_("Must provide object_id.")) + + try: + return model_class.objects.get(pk=obj_id) + except (model_class.DoesNotExist, ValueError, TypeError, OverflowError): + raise serializers.ValidationError(_("No matching object found.")) from None + + class CustomObjectTypeFieldSerializer(NetBoxModelSerializer): url = serializers.HyperlinkedIdentityField( view_name="plugins-api:netbox_custom_objects-api:customobjecttypefield-detail" ) - app_label = serializers.CharField(required=False) - model = serializers.CharField(required=False) + app_label = serializers.CharField(required=False, write_only=True) + model = serializers.CharField(required=False, write_only=True) + # Read-only nested representation of the single related object type (non-polymorphic) + related_object_type = serializers.SerializerMethodField() + # Read-only nested representation of multiple allowed types (polymorphic) + related_object_types = serializers.SerializerMethodField() + # For polymorphic fields: list of {app_label, model} dicts + related_object_types_input = serializers.ListField( + child=serializers.DictField(), + required=False, + write_only=True, + help_text="List of {app_label, model} dicts for polymorphic field types", + ) class Meta: model = CustomObjectTypeField @@ -64,11 +147,14 @@ class Meta: "validation_regex", "validation_minimum", "validation_maximum", + "is_polymorphic", "related_object_type", + "related_object_types", "related_object_filter", "related_name", "app_label", "model", + "related_object_types_input", "group_name", "search_weight", "filter_logic", @@ -79,51 +165,111 @@ class Meta: "comments", ) + def _resolve_object_type(self, app_label, model): + """Resolve a single app_label+model pair to an ObjectType, handling aliases.""" + if app_label == _PUBLIC_APP_LABEL: + app_label = constants.APP_LABEL + if app_label == constants.APP_LABEL and model and not _TABLE_MODEL_PATTERN.match(model): + try: + cot = CustomObjectType.objects.get(slug=model) + model = CustomObjectType.get_table_model_name(cot.id).lower() + except CustomObjectType.DoesNotExist: + raise ValidationError(_("Invalid custom object type slug.")) + try: + return ObjectType.objects.get(app_label=app_label, model=model) + except ObjectType.DoesNotExist: + raise ValidationError( + _("Must provide a valid app_label and model for the object field type.") + ) + def validate(self, attrs): + # Guard immutable fields on existing instances. + if self.instance and self.instance.pk: + if "is_polymorphic" in attrs and bool(attrs["is_polymorphic"]) != bool(self.instance.is_polymorphic): + raise ValidationError( + {"is_polymorphic": _("Cannot change the polymorphic flag after field creation.")} + ) + if attrs.get("related_object_types_input") is not None: + # Resolve aliases (public app_label, COT slug as model name) before + # comparing so that a PUT/PATCH round-tripping the same logical types + # using alias forms is not rejected as a change. + # If resolution raises ValidationError here (invalid type in the + # payload) we skip the immutability guard — the error will surface + # again when the same entry is resolved in the normal validation path. + try: + resolved_incoming = [ + self._resolve_object_type( + entry.get("app_label", ""), entry.get("model", "") + ) + for entry in attrs["related_object_types_input"] + ] + except ValidationError: + resolved_incoming = None + + if resolved_incoming is not None: + incoming = frozenset( + (ot.app_label, ot.model) for ot in resolved_incoming + ) + existing = frozenset( + (ot.app_label, ot.model) + for ot in self.instance.related_object_types.all() + ) + if incoming != existing: + raise ValidationError( + {"related_object_types_input": _( + "Cannot change allowed object types after field creation." + )} + ) + if attrs.get("app_label") or attrs.get("model"): + raise ValidationError( + _("Cannot change the related object type after field creation.") + ) + app_label = attrs.pop("app_label", None) model = attrs.pop("model", None) - if attrs["type"] in [ + related_object_types_input = attrs.pop("related_object_types_input", None) + is_polymorphic = attrs.get("is_polymorphic", False) + + field_type = attrs.get("type") + + if field_type in [ CustomFieldTypeChoices.TYPE_OBJECT, CustomFieldTypeChoices.TYPE_MULTIOBJECT, ]: - # Allow the public URL slug "custom-objects" as an alias for the internal app label - if app_label == _PUBLIC_APP_LABEL: - app_label = constants.APP_LABEL - - # When targeting custom objects, allow the CustomObjectType slug as the model name - if app_label == constants.APP_LABEL and model and not _TABLE_MODEL_PATTERN.match(model): - try: - cot = CustomObjectType.objects.get(slug=model) - model = CustomObjectType.get_table_model_name(cot.id).lower() - except CustomObjectType.DoesNotExist: + if is_polymorphic: + # Polymorphic: resolve from related_object_types_input list + if related_object_types_input: + resolved = [] + for entry in related_object_types_input: + al = entry.get("app_label", "") + m = entry.get("model", "") + resolved.append(self._resolve_object_type(al, m)) + attrs["related_object_types"] = resolved + elif not attrs.get("related_object_types"): raise ValidationError( - "Must provide valid app_label and model for object field type." + _("Polymorphic object fields require related_object_types_input or related_object_types.") + ) + else: + # Non-polymorphic: resolve single type from app_label+model or related_object_type + if app_label or model: + attrs["related_object_type"] = self._resolve_object_type( + app_label or "", model or "" + ) + elif not attrs.get("related_object_type"): + raise ValidationError( + _("Must provide app_label and model (or related_object_type) for object field type.") ) - try: - attrs["related_object_type"] = ObjectType.objects.get( - app_label=app_label, model=model - ) - except ObjectType.DoesNotExist: - raise ValidationError( - "Must provide valid app_label and model for object field type." - ) - if attrs["type"] in [ + if field_type in [ CustomFieldTypeChoices.TYPE_SELECT, CustomFieldTypeChoices.TYPE_MULTISELECT, ]: if not attrs.get("choice_set", None): raise ValidationError( - "Must provide choice_set with valid PK for select field type." + _("Must provide choice_set with valid PK for select field type.") ) return super().validate(attrs) - def create(self, validated_data): - """ - Record the user who created the Custom Object as its owner. - """ - return super().create(validated_data) - def get_related_object_type(self, obj): if obj.related_object_type: return dict( @@ -131,6 +277,13 @@ def get_related_object_type(self, obj): app_label=obj.related_object_type.app_label, model=obj.related_object_type.model, ) + return None + + def get_related_object_types(self, obj): + return [ + dict(id=ot.id, app_label=ot.app_label, model=ot.model) + for ot in obj.related_object_types.all() + ] class CustomObjectTypeSerializer(NetBoxModelSerializer): @@ -171,9 +324,6 @@ def get_table_model_name(self, obj): def get_object_type_name(self, obj): return f"{constants.APP_LABEL}.{obj.get_table_model_name(obj.id).lower()}" - def create(self, validated_data): - return super().create(validated_data) - # TODO: Remove or reduce to a stub (not needed as all custom object serializers are generated via get_serializer_class) class CustomObjectSerializer(NetBoxModelSerializer): @@ -207,9 +357,6 @@ class Meta: def get_display(self, obj): return f"{obj.custom_object_type}: {obj.name}" - def validate(self, attrs): - return super().validate(attrs) - def update_relation_fields(self, instance): # TODO: Implement this pass @@ -227,10 +374,9 @@ def update(self, instance, validated_data): def get_url(self, obj): """ - Given an object, return the URL that hyperlinks to the object. - - May raise a `NoReverseMatch` if the `view_name` and `lookup_field` - attributes are not configured to correctly match the URL conf. + Given an object, return the URL that hyperlinks to the object, or None + if the URL cannot be resolved (e.g. the COT slug has changed since the + object was serialised, or the URL conf is misconfigured). """ # Unsaved objects will not yet have a valid URL. if hasattr(obj, "pk") and obj.pk in (None, ""): @@ -244,7 +390,10 @@ def get_url(self, obj): } request = self.context["request"] format = self.context.get("format") - return reverse(view_name, kwargs=kwargs, request=request, format=format) + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + return None def get_field_data(self, obj): result = {} @@ -252,6 +401,12 @@ def get_field_data(self, obj): def get_serializer_class(model, skip_object_fields=False): + # This function is intentionally not cached at the serializer level. + # It is called per-request (via CustomObjectViewSet.get_serializer_class → + # get_model_with_serializer), and the model itself is cache-invalidated on + # field post_save/pre_delete via clear_model_cache(). Keeping serializer + # generation fresh ensures _poly_obj_fields/_poly_m2m_fields always reflect + # the current set of polymorphic fields without a separate invalidation path. model_fields = model.custom_object_type.fields.all() # Create field list including all necessary fields @@ -286,7 +441,8 @@ def get_serializer_class(model, skip_object_fields=False): ) def get_url(self, obj): - """Generate the API URL for this object""" + """Generate the API URL for this object, or None if the URL cannot be + resolved (e.g. the COT slug changed since the object was serialized).""" if hasattr(obj, "pk") and obj.pk in (None, ""): return None @@ -298,12 +454,27 @@ def get_url(self, obj): } request = self.context["request"] format = self.context.get("format") - return reverse(view_name, kwargs=kwargs, request=request, format=format) + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + return None def get_display(self, obj): """Get display representation of the object""" return str(obj) + # Collect polymorphic field names for special handling in create/update + _poly_obj_fields = { + f.name for f in model.custom_object_type.fields.filter( + type=CustomFieldTypeChoices.TYPE_OBJECT, is_polymorphic=True + ) + } + _poly_m2m_fields = { + f.name for f in model.custom_object_type.fields.filter( + type=CustomFieldTypeChoices.TYPE_MULTIOBJECT, is_polymorphic=True + ) + } + def get__context(self, obj): """Return context field values as a nested display object for APISelect secondary text.""" context_parts = [] @@ -328,6 +499,18 @@ def create(self, validated_data): if relation_info.to_many and (field_name in validated_data): many_to_many[field_name] = validated_data.pop(field_name) + # Pop polymorphic GFK fields (set after instance creation via descriptor) + poly_gfk = {} + for field_name in _poly_obj_fields: + if field_name in validated_data: + poly_gfk[field_name] = validated_data.pop(field_name) + + # Pop polymorphic M2M fields (set after instance creation via manager) + poly_m2m = {} + for field_name in _poly_m2m_fields: + if field_name in validated_data: + poly_m2m[field_name] = validated_data.pop(field_name) + instance = ModelClass._default_manager.create(**validated_data) if many_to_many: @@ -335,12 +518,33 @@ def create(self, validated_data): field = getattr(instance, field_name) field.set(value) + for field_name, value in poly_gfk.items(): + setattr(instance, field_name, value) + if poly_gfk: + instance.save() + + for field_name, value in poly_m2m.items(): + mgr = getattr(instance, field_name) + mgr.set(value) + return instance # Stock DRF update() with custom field.set() for M2M def update(self, instance, validated_data): info = model_meta.get_field_info(instance) + # Pop polymorphic GFK fields + poly_gfk = {} + for field_name in _poly_obj_fields: + if field_name in validated_data: + poly_gfk[field_name] = validated_data.pop(field_name) + + # Pop polymorphic M2M fields + poly_m2m = {} + for field_name in _poly_m2m_fields: + if field_name in validated_data: + poly_m2m[field_name] = validated_data.pop(field_name) + m2m_fields = [] for attr, value in validated_data.items(): if attr in info.relations and info.relations[attr].to_many: @@ -348,14 +552,37 @@ def update(self, instance, validated_data): else: setattr(instance, attr, value) + for field_name, value in poly_gfk.items(): + setattr(instance, field_name, value) + instance.save() for attr, value in m2m_fields: field = getattr(instance, attr) field.set(value, clear=True) + for field_name, value in poly_m2m.items(): + mgr = getattr(instance, field_name) + mgr.set(value, clear=True) + return instance + def validate(self, data): + # NetBoxModelSerializer.validate() calls Model(**attrs) to check field + # values. Polymorphic GFK and M2M fields are not real Django model fields, + # so they'd cause a TypeError. Pop them before delegating to the parent, + # then restore them afterward. + # super() is unavailable here because this function is defined outside a + # class body (no __class__ cell). The generated class has a single base + # (NetBoxModelSerializer), so calling it directly is equivalent. + saved = {} + for field_name in (*_poly_obj_fields, *_poly_m2m_fields): + if field_name in data: + saved[field_name] = data.pop(field_name) + data = NetBoxModelSerializer.validate(self, data) + data.update(saved) + return data + # Create basic attributes for the serializer attrs = { "Meta": meta, @@ -366,6 +593,7 @@ def update(self, instance, validated_data): "get_display": get_display, "create": create, "update": update, + "validate": validate, } if has_context_fields: @@ -381,9 +609,22 @@ def update(self, instance, validated_data): try: attrs[field.name] = field_type.get_serializer_field(field) except NotImplementedError: + # Field type intentionally has no serializer representation; omit it. logger.debug( - "serializer: {} field is not implemented; using a default serializer field".format(field.name) + "serializer: field %r (type %r) has no serializer implementation; skipping", + field.name, field.type, + ) + except Exception as exc: + # Unexpected error (e.g. ContentType.DoesNotExist from a deleted + # ContentType row). Fall back to a permissive JSONField so the + # serializer remains functional and the error doesn't surface as a + # 500 to the caller. Log at WARNING so it's visible in production. + logger.warning( + "serializer: failed to build serializer field for %r (type %r): %s; " + "falling back to JSONField", + field.name, field.type, exc, ) + attrs[field.name] = serializers.JSONField(required=False, allow_null=True) serializer_name = f"{model._meta.object_name}Serializer" serializer = type( diff --git a/netbox_custom_objects/api/views.py b/netbox_custom_objects/api/views.py index 3ce1107d..8d343d96 100644 --- a/netbox_custom_objects/api/views.py +++ b/netbox_custom_objects/api/views.py @@ -31,7 +31,7 @@ def get_view_name(self): class CustomObjectTypeViewSet(ModelViewSet): - queryset = CustomObjectType.objects.all() + queryset = CustomObjectType.objects.prefetch_related('fields__related_object_types') serializer_class = serializers.CustomObjectTypeSerializer @@ -120,7 +120,7 @@ def perform_destroy(self, instance): class CustomObjectTypeFieldViewSet(ModelViewSet): - queryset = CustomObjectTypeField.objects.all() + queryset = CustomObjectTypeField.objects.prefetch_related('related_object_types') serializer_class = serializers.CustomObjectTypeFieldSerializer diff --git a/netbox_custom_objects/field_types.py b/netbox_custom_objects/field_types.py index 3b881f2c..77b0b11c 100644 --- a/netbox_custom_objects/field_types.py +++ b/netbox_custom_objects/field_types.py @@ -1,12 +1,16 @@ +import hashlib import json +import logging import django_tables2 as tables from django import forms from django.apps import apps +from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType from django.contrib.postgres.fields import ArrayField +from django.core.exceptions import FieldDoesNotExist from django.core.validators import RegexValidator -from django.db import models +from django.db import connection, models from django.db.models.fields.related import ForeignKey, ManyToManyDescriptor from django.db.models.manager import Manager from django.utils.html import escape @@ -33,6 +37,27 @@ from netbox_custom_objects.constants import APP_LABEL from netbox_custom_objects.utilities import extract_cot_id_from_model_name, generate_model +logger = logging.getLogger(__name__) + +# PostgreSQL's hard limit for identifier names is 63 bytes. +_PG_MAX_IDENTIFIER_LEN = 63 + + +def _safe_index_name(full_name: str) -> str: + """ + Return a DB-safe index name that fits within PostgreSQL's 63-char identifier limit. + + If the full name fits, it is returned unchanged. If it is too long, the name is + truncated and an 8-character MD5 digest of the *full* name is appended so that + different long names with a shared prefix do not collide. + """ + if len(full_name) <= _PG_MAX_IDENTIFIER_LEN: + return full_name + digest = hashlib.md5(full_name.encode()).hexdigest()[:8] + # Reserve 9 chars for "_" + 8-char digest; strip trailing underscores from the prefix. + prefix = full_name[:_PG_MAX_IDENTIFIER_LEN - 9].rstrip("_") + return f"{prefix}_{digest}" + class LazyForeignKey(ForeignKey): """ @@ -395,6 +420,23 @@ def render_table_column(self, value): class ObjectFieldType(FieldType): def get_model_field(self, field, **kwargs): + if field.is_polymorphic: + # Polymorphic Object: two concrete columns + one virtual GFK descriptor + ct_field_name = f"{field.name}_content_type" + oid_field_name = f"{field.name}_object_id" + return { + ct_field_name: models.ForeignKey( + "contenttypes.ContentType", + null=True, + blank=True, + on_delete=models.SET_NULL, + related_name="+", + db_column=f"{field.name}_content_type_id", + ), + oid_field_name: models.PositiveBigIntegerField(null=True, blank=True), + field.name: GenericForeignKey(ct_field_name, oid_field_name), + } + content_type = ContentType.objects.get(pk=field.related_object_type_id) to_model = content_type.model @@ -460,6 +502,12 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): For custom objects, uses CustomObjectDynamicModelChoiceField. For regular NetBox objects, uses DynamicModelChoiceField. """ + if field.is_polymorphic: + # Polymorphic form field not yet supported in the UI; skip gracefully + raise NotImplementedError( + "Polymorphic object form fields are rendered by the view layer, not via this method." + ) + content_type = ContentType.objects.get(pk=field.related_object_type_id) has_context = False @@ -508,6 +556,8 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): return form_field def get_filterform_field(self, field, **kwargs): + if field.is_polymorphic: + return None # Filtering polymorphic fields not supported yet content_type = ContentType.objects.get(pk=field.related_object_type_id) if content_type.app_label == APP_LABEL: from netbox_custom_objects.models import CustomObjectType @@ -531,6 +581,14 @@ def render_table_column(self, value): return linkify(value) def get_serializer_field(self, field, **kwargs): + if field.is_polymorphic: + from netbox_custom_objects.api.serializers import PolymorphicObjectSerializerField + allowed_ids = {ot.id for ot in field.related_object_types.all()} + return PolymorphicObjectSerializerField( + allowed_content_type_ids=allowed_ids, + required=field.required, + allow_null=not field.required, + ) related_model_class = field.related_object_type.model_class() if related_model_class._meta.app_label == APP_LABEL: from netbox_custom_objects.api.serializers import get_serializer_class @@ -544,11 +602,109 @@ def after_model_generation(self, instance, model, field_name): Resolve lazy references after the model is fully generated. This ensures that self-referential fields point to the correct model class. """ + if instance.is_polymorphic: + return # GFK needs no post-generation resolution # Check if this field has a resolution method if resolve_method := getattr(model, f'_resolve_{field_name}_model', None): resolve_method(model) + def add_polymorphic_object_columns(self, field_instance, model, schema_editor): + """ + Add the two concrete DB columns (content_type FK + object_id) for a polymorphic + Object field, plus a composite index on both columns. + """ + ct_field_name = f"{field_instance.name}_content_type" + oid_field_name = f"{field_instance.name}_object_id" + ct_field = models.ForeignKey( + "contenttypes.ContentType", + null=True, + blank=True, + on_delete=models.SET_NULL, + related_name="+", + db_column=f"{field_instance.name}_content_type_id", + ) + ct_field.contribute_to_class(model, ct_field_name) + schema_editor.add_field(model, ct_field) + + oid_field = models.PositiveBigIntegerField(null=True, blank=True) + oid_field.contribute_to_class(model, oid_field_name) + schema_editor.add_field(model, oid_field) + + # Composite index as recommended in issue #31 + idx_name = _safe_index_name( + f"co_{field_instance.custom_object_type_id}_{field_instance.name}_gfk" + ) + idx = models.Index(fields=[ct_field_name, oid_field_name], name=idx_name) + schema_editor.add_index(model, idx) + + def remove_polymorphic_object_columns(self, field_instance, model, schema_editor): + """ + Remove the concrete DB columns for a polymorphic Object field. + + ``schema_editor`` must be supplied by the caller so that all DDL in a + single operation (e.g. field deletion) runs within one schema editor + context. Opening a nested ``with connection.schema_editor()`` here + would cause deferred_sql from the outer context to be flushed at the + wrong time on some backends. + """ + ct_field_name = f"{field_instance.name}_content_type" + oid_field_name = f"{field_instance.name}_object_id" + + try: + oid_field = model._meta.get_field(oid_field_name) + schema_editor.remove_field(model, oid_field) + except FieldDoesNotExist: + pass # Column already absent — nothing to remove. + except Exception: + logger.warning( + "Failed to remove polymorphic object_id column %r from %r", + oid_field_name, model._meta.db_table, exc_info=True, + ) + try: + ct_field = model._meta.get_field(ct_field_name) + schema_editor.remove_field(model, ct_field) + except FieldDoesNotExist: + pass # Column already absent — nothing to remove. + except Exception: + logger.warning( + "Failed to remove polymorphic content_type column %r from %r", + ct_field_name, model._meta.db_table, exc_info=True, + ) + + +# WHY CustomManyToManyManager / CustomManyToManyDescriptor / CustomManyToManyField +# exist instead of using Django's built-in ManyToManyField +# ────────────────────────────────────────────────────────────────────────────── +# Django's ManyToManyField assumes both sides of the relation are registered in +# the app registry *before* any model is instantiated. Custom object types are +# defined at runtime by end-users and their models are generated dynamically via +# `type(...)`. This creates two problems: +# +# 1. The through model does not exist in the app registry at import time, so +# Django's ManyRelatedManager cannot resolve `field.remote_field.through` +# during class construction. Attempting to register it later causes +# "model was already registered" RuntimeWarnings (suppressed in +# generate_model()) and occasional stale-cache issues. +# +# 2. Django's `get_prefetch_queryset` (and the newer `get_prefetch_querysets` +# introduced in Django 4.2) builds its result queryset from the through +# model's manager, which requires the through model to be stable in the +# registry. Because our through models are regenerated on every server +# restart (and on every schema change), the registry entry can be stale, +# causing prefetch_related() to fetch from the wrong table. +# +# CustomManyToManyManager sidesteps both issues by resolving the through model +# directly from the field instance at access time rather than from the registry, +# and by implementing get_prefetch_queryset with explicit source/target subquery +# joins that work regardless of registry state. +# +# MAINTENANCE NOTE: get_prefetch_queryset returns a private Django tuple format. +# The six-element tuple (queryset, fk_getter, rel_obj_getter, single, cache_name, +# is_descriptor) is documented only in Django internals and may change between +# major versions. If a Django upgrade breaks prefetch_related() for custom M2M +# fields, this is the first place to check. The Django source to compare against +# is django/db/models/fields/related_managers.py :: ManyRelatedManager. class CustomManyToManyManager(Manager): def __init__(self, instance=None, field_name=None): super().__init__() @@ -783,6 +939,11 @@ def get_model_field(self, field, **kwargs): """ Creates the M2M field with appropriate model references """ + if field.is_polymorphic: + # Polymorphic MultiObject: return a descriptor instead of a real M2M field. + # The descriptor manages a through table with (source_id, content_type_id, object_id). + return PolymorphicM2MDescriptor(through_model_name=field.through_model_name) + # Check if this is a self-referential M2M content_type = ContentType.objects.get(pk=field.related_object_type_id) custom_object_type_id = extract_cot_id_from_model_name(content_type.model) @@ -838,6 +999,9 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): Returns a form field for multi-object relationships. Uses DynamicModelMultipleChoiceField for both custom objects and regular NetBox objects. """ + if field.is_polymorphic: + raise NotImplementedError("Polymorphic multi-object fields are managed via the API") + content_type = ContentType.objects.get(pk=field.related_object_type_id) has_context = False @@ -883,7 +1047,34 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): form_field.widget.attrs['ts-parent-field'] = '_context' return form_field + def get_display_value(self, instance, field_name): + field = getattr(instance, field_name) + return ", ".join(str(s) for s in field.all()) + + def get_table_column_field(self, field, **kwargs): + return tables.ManyToManyColumn(linkify_item=True, orderable=False) + + def get_serializer_field(self, field, **kwargs): + if field.is_polymorphic: + from netbox_custom_objects.api.serializers import PolymorphicObjectSerializerField + from rest_framework import serializers as drf_serializers + allowed_ids = {ot.id for ot in field.related_object_types.all()} + return drf_serializers.ListField( + child=PolymorphicObjectSerializerField(allowed_content_type_ids=allowed_ids), + required=field.required, + allow_empty=True, + ) + related_model_class = field.related_object_type.model_class() + if related_model_class._meta.app_label == APP_LABEL: + from netbox_custom_objects.api.serializers import get_serializer_class + serializer = get_serializer_class(related_model_class, skip_object_fields=True) + else: + serializer = get_serializer_for_model(related_model_class) + return serializer(required=field.required, nested=True, many=True) + def get_filterform_field(self, field, **kwargs): + if field.is_polymorphic: + return None # Filtering polymorphic fields not supported yet content_type = ContentType.objects.get(pk=field.related_object_type_id) if content_type.app_label == APP_LABEL: from netbox_custom_objects.models import CustomObjectType @@ -903,26 +1094,13 @@ def get_filterform_field(self, field, **kwargs): selector=model._meta.app_label != APP_LABEL, ) - def get_display_value(self, instance, field_name): - field = getattr(instance, field_name) - return ", ".join(str(s) for s in field.all()) - - def get_table_column_field(self, field, **kwargs): - return tables.ManyToManyColumn(linkify_item=True, orderable=False) - - def get_serializer_field(self, field, **kwargs): - related_model_class = field.related_object_type.model_class() - if related_model_class._meta.app_label == APP_LABEL: - from netbox_custom_objects.api.serializers import get_serializer_class - serializer = get_serializer_class(related_model_class, skip_object_fields=True) - else: - serializer = get_serializer_for_model(related_model_class) - return serializer(required=field.required, nested=True, many=True) - def after_model_generation(self, instance, model, field_name): """ After both models are generated, update the field's remote model references """ + if instance.is_polymorphic: + return # PolymorphicM2MDescriptor needs no post-generation resolution + field = model._meta.get_field(field_name) # Skip model resolution for self-referential fields @@ -1056,6 +1234,295 @@ def create_m2m_table(self, instance, model, field_name): if table_name not in tables: schema_editor.create_model(through_model) + def get_polymorphic_through_model(self, field_instance, source_model_string): + """ + Creates a through model for a polymorphic MultiObject field. + Columns: source_id (FK to custom object), content_type_id (FK to ContentType), + object_id (PositiveBigIntegerField). + """ + meta = type( + "Meta", + (), + { + "db_table": field_instance.through_table_name, + "app_label": APP_LABEL, + "apps": apps, + "managed": True, + "unique_together": (("source", "content_type", "object_id"),), + "indexes": [ + models.Index( + fields=["content_type", "object_id"], + name=_safe_index_name( + f"co_{field_instance.custom_object_type_id}" + f"_{field_instance.name}_pgfk" + ), + ) + ], + }, + ) + + attrs = { + "__module__": "netbox_custom_objects.models", + "Meta": meta, + "id": models.AutoField(primary_key=True), + "source": models.ForeignKey( + source_model_string, + on_delete=models.CASCADE, + related_name="+", + db_column="source_id", + ), + "content_type": models.ForeignKey( + "contenttypes.ContentType", + on_delete=models.CASCADE, + related_name="+", + db_column="content_type_id", + ), + "object_id": models.PositiveBigIntegerField(db_column="object_id"), + } + + return generate_model(field_instance.through_model_name, (models.Model,), attrs) + + def create_polymorphic_m2m_table(self, field_instance, model): + """ + Creates the DB table for a polymorphic MultiObject through model. + """ + source_model_string = f"{APP_LABEL}.{model.__name__}" + through = self.get_polymorphic_through_model(field_instance, source_model_string) + + # Update source FK to point to the actual model + source_field = through._meta.get_field("source") + source_field.remote_field.model = model + source_field.related_model = model + + # Register with Django's app registry + _apps = model._meta.apps + try: + through_model = _apps.get_model(APP_LABEL, field_instance.through_model_name) + except LookupError: + _apps.register_model(APP_LABEL, through) + through_model = through + + with connection.schema_editor() as schema_editor: + table_name = through_model._meta.db_table + with connection.cursor() as cursor: + existing_tables = connection.introspection.table_names(cursor) + if table_name not in existing_tables: + schema_editor.create_model(through_model) + + def drop_polymorphic_m2m_table(self, field_instance, model, schema_editor): + """ + Drops the DB table for a polymorphic MultiObject through model. + + ``schema_editor`` must be supplied by the caller for the same reason as + ``remove_polymorphic_object_columns``: all DDL in a single operation + should share one schema editor context. + """ + _apps = model._meta.apps + try: + through_model = _apps.get_model(APP_LABEL, field_instance.through_model_name) + schema_editor.delete_model(through_model) + except LookupError: + pass # Already dropped or never created + + +class PolymorphicResultList: + """ + Lazy result returned by PolymorphicManyToManyManager.all(). + + The underlying DB queries are deferred until first access and cached + within this object's lifetime. Because PolymorphicM2MDescriptor creates + a new manager on every attribute access, and the manager's all() creates a + new PolymorphicResultList, the cache only helps *within a single call + chain* — e.g. a template that calls ``|length`` and then iterates the same + ``all()`` return value will only issue one round of queries. Calling + ``obj.poly_field.all()`` twice, however, creates two separate instances and + issues two rounds of queries. + + This is intentionally NOT a QuerySet — the objects come from multiple + model classes and cannot be combined into a single SQL result set. + It supports the subset of the list/queryset interface that templates and + common callers need: iteration, ``len()``, ``bool()``, and index access. + """ + + __slots__ = ("_factory", "_cache") + + def __init__(self, factory): + # factory is a zero-argument callable that returns an iterator of objects. + self._factory = factory + self._cache = None + + def _evaluate(self): + if self._cache is None: + self._cache = list(self._factory()) + return self._cache + + def __iter__(self): + return iter(self._evaluate()) + + def __len__(self): + return len(self._evaluate()) + + def __bool__(self): + return bool(self._evaluate()) + + def __getitem__(self, index): + return self._evaluate()[index] + + def __repr__(self): + return repr(self._evaluate()) + + +class PolymorphicManyToManyManager: + """ + Manager for polymorphic many-to-many relationships. + Handles objects from multiple model types via a through table with + (source_id, content_type_id, object_id) columns. + """ + + def __init__(self, instance, field_name, through_model_name): + self.instance = instance + self.field_name = field_name + self.through_model_name = through_model_name + + def _get_through_model(self): + return apps.get_model(APP_LABEL, self.through_model_name) + + def _get_objects(self): + through = self._get_through_model() + rows = list( + through.objects.filter(source_id=self.instance.pk) + .values_list("content_type_id", "object_id") + .order_by("id") + ) + + # Group object IDs by content type so we can batch-fetch per model class + # (one SELECT per type) rather than issuing one SELECT per row. + by_ct: dict[int, list] = {} + for ct_id, obj_id in rows: + by_ct.setdefault(ct_id, []).append(obj_id) + + # Build a lookup map: (ct_id, obj_id) → object, preserving row order below. + obj_map: dict[tuple, object] = {} + for ct_id, obj_ids in by_ct.items(): + ct = ContentType.objects.get_for_id(ct_id) + model_class = ct.model_class() + if model_class is None: + continue + for obj in model_class.objects.filter(pk__in=obj_ids): + obj_map[(ct_id, obj.pk)] = obj + + # Yield in original insertion order, skipping stale references. + for ct_id, obj_id in rows: + obj = obj_map.get((ct_id, obj_id)) + if obj is not None: + yield obj + + def all(self): + return PolymorphicResultList(self._get_objects) + + def count(self): + return self._get_through_model().objects.filter(source_id=self.instance.pk).count() + + def exists(self): + return self._get_through_model().objects.filter(source_id=self.instance.pk).exists() + + def add(self, *objs): + through = self._get_through_model() + for obj in objs: + ct = ContentType.objects.get_for_model(obj) + through.objects.get_or_create( + source_id=self.instance.pk, + content_type_id=ct.pk, + object_id=obj.pk, + ) + + def remove(self, *objs): + through = self._get_through_model() + for obj in objs: + ct = ContentType.objects.get_for_model(obj) + through.objects.filter( + source_id=self.instance.pk, + content_type_id=ct.pk, + object_id=obj.pk, + ).delete() + + def clear(self): + self._get_through_model().objects.filter(source_id=self.instance.pk).delete() + + def set(self, objs, clear=False): + if clear: + self.clear() + self.add(*objs) + else: + # Diff-based replacement: add new, remove old. Matches Django's + # standard ManyRelatedManager.set(clear=False) behaviour. + objs = tuple(objs) + through = self._get_through_model() + existing = { + (ct_id, obj_id) + for ct_id, obj_id in through.objects.filter(source_id=self.instance.pk) + .values_list("content_type_id", "object_id") + } + # Pre-compute (ct_id, obj_pk) once per object to avoid duplicate CT lookups. + new_items = [ + (ContentType.objects.get_for_model(obj).pk, obj.pk, obj) for obj in objs + ] + new_keys = {(ct_id, obj_pk) for ct_id, obj_pk, _ in new_items} + to_add = [obj for ct_id, obj_pk, obj in new_items if (ct_id, obj_pk) not in existing] + to_remove = existing - new_keys + if to_add: + self.add(*to_add) + for ct_id, obj_id in to_remove: + through.objects.filter( + source_id=self.instance.pk, + content_type_id=ct_id, + object_id=obj_id, + ).delete() + + def __iter__(self): + return iter(self.all()) + + +class PolymorphicM2MDescriptor: + """ + Descriptor for polymorphic many-to-many fields. + Added directly to the model's class attrs during model generation. + """ + + def __init__(self, through_model_name): + self.through_model_name = through_model_name + self.field_name = None + + def __set_name__(self, owner, name): + self.field_name = name + + def contribute_to_class(self, cls, name): + self.field_name = name + setattr(cls, name, self) + + def __get__(self, instance, owner=None): + if instance is None: + return self + return PolymorphicManyToManyManager( + instance=instance, + field_name=self.field_name, + through_model_name=self.through_model_name, + ) + + def __set__(self, instance, value): + raise AttributeError( + f"Direct assignment to '{self.field_name}' is not supported. " + f"Use '{self.field_name}.set(objs)' to update polymorphic M2M fields." + ) + + @property + def many_to_many(self): + return True + + @property + def concrete(self): + return False + FIELD_TYPE_CLASS = { CustomFieldTypeChoices.TYPE_TEXT: TextFieldType, diff --git a/netbox_custom_objects/forms.py b/netbox_custom_objects/forms.py index 0642ca34..df4f8e1b 100644 --- a/netbox_custom_objects/forms.py +++ b/netbox_custom_objects/forms.py @@ -5,8 +5,10 @@ from netbox.forms import (NetBoxModelBulkEditForm, NetBoxModelFilterSetForm, NetBoxModelForm, NetBoxModelImportForm) from utilities.forms.fields import (CommentField, ContentTypeChoiceField, + ContentTypeMultipleChoiceField, DynamicModelChoiceField, SlugField, TagFilterField) from utilities.forms.rendering import FieldSet +from utilities.forms.utils import get_field_value from utilities.object_types import object_type_name from netbox_custom_objects.choices import SearchWeightChoices @@ -121,6 +123,25 @@ def label_from_instance(self, obj): return super().label_from_instance(obj) +class CustomContentTypeMultipleChoiceField(ContentTypeMultipleChoiceField): + """Multi-select version of CustomContentTypeChoiceField for polymorphic object fields.""" + + def label_from_instance(self, obj): + if obj.app_label == APP_LABEL: + custom_object_type_id = extract_cot_id_from_model_name(obj.model) + if custom_object_type_id is not None: + try: + return CustomObjectType.get_content_type_label( + custom_object_type_id + ) + except CustomObjectType.DoesNotExist: + pass + try: + return object_type_name(obj) + except AttributeError: + return super().label_from_instance(obj) + + class CustomObjectTypeFieldForm(CustomFieldForm): # This field should be removed or at least "required" should be defeated object_types = forms.CharField( @@ -136,7 +157,17 @@ class CustomObjectTypeFieldForm(CustomFieldForm): related_object_type = CustomContentTypeChoiceField( label=_("Related object type"), queryset=CustomObjectObjectType.objects.public(), - help_text=_("Type of the related object (for object/multi-object fields only)"), + required=False, + help_text=_("Type of the related object (for non-polymorphic object/multi-object fields)"), + ) + related_object_types = CustomContentTypeMultipleChoiceField( + label=_("Related object types"), + queryset=CustomObjectObjectType.objects.public(), + required=False, + help_text=_( + "Allowed object types for a polymorphic field (select one or more). " + "Only used when 'Polymorphic' is enabled." + ), ) search_weight = forms.ChoiceField( choices=SearchWeightChoices, @@ -162,6 +193,13 @@ class CustomObjectTypeFieldForm(CustomFieldForm): "default", name=_("Field"), ), + FieldSet( + "is_polymorphic", + "related_object_type", + "related_object_types", + "related_object_filter", + name=_("Related Object"), + ), FieldSet( "search_weight", "filter_logic", @@ -180,14 +218,74 @@ class Meta: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Disable changing the custom object type or related object type of a field + # Toggling the polymorphic checkbox should re-render the form so only the + # relevant related-object field is shown. + self.fields['is_polymorphic'].widget.attrs.update({ + 'hx-get': '.', + 'hx-include': '#form_fields', + 'hx-target': '#form_fields', + }) + + # Determine current field type and polymorphic state. + # For existing instances is_polymorphic cannot be changed, so read it from the + # instance directly; for new fields use whatever the form currently carries. + field_type = get_field_value(self, 'type') + if self.instance.pk: + is_polymorphic = self.instance.is_polymorphic + elif self.is_bound: + # get_field_value() falls back to initial for BooleanField (no valid_value); + # read the submitted checkbox value from self.data directly instead. + is_polymorphic = bool(self.data.get('is_polymorphic')) + else: + is_polymorphic = bool(get_field_value(self, 'is_polymorphic')) + + # Show only the relevant related-object field and rebuild fieldsets cleanly. + # The parent __init__ inserts a simple FieldSet('related_object_type', ...) for + # object/multiobject types, which would create a duplicate section; replacing + # self.fieldsets here keeps a single "Related Object" group. + if field_type in (CustomFieldTypeChoices.TYPE_OBJECT, CustomFieldTypeChoices.TYPE_MULTIOBJECT): + if is_polymorphic: + if 'related_object_type' in self.fields: + del self.fields['related_object_type'] + related_obj_fields = ('is_polymorphic', 'related_object_types', 'related_object_filter') + else: + if 'related_object_types' in self.fields: + del self.fields['related_object_types'] + related_obj_fields = ('is_polymorphic', 'related_object_type', 'related_object_filter') + self.fieldsets = ( + CustomObjectTypeFieldForm.fieldsets[0], + FieldSet(*related_obj_fields, name=_('Related Object')), + CustomObjectTypeFieldForm.fieldsets[2], + ) + else: + # Parent already removed related_object_type/related_object_filter; + # remove the remaining related-object fields too. + for fname in ('related_object_types', 'is_polymorphic'): + if fname in self.fields: + del self.fields[fname] + # Drop the Related Object fieldset entirely so no empty section header renders. + # Filter by checking that every item in a fieldset belongs to the related-object + # field set (handles both our full FieldSet and any parent-inserted simple one). + _related_names = frozenset({ + 'is_polymorphic', 'related_object_type', 'related_object_types', 'related_object_filter', + }) + self.fieldsets = tuple( + fs for fs in self.fieldsets + if not all(isinstance(item, str) and item in _related_names for item in fs.items) + ) + + # Disable immutable fields on existing instances. if self.instance.pk: self.fields["custom_object_type"].disabled = True - if "related_object_type" in self.fields: + if 'is_polymorphic' in self.fields: + self.fields["is_polymorphic"].disabled = True + if 'related_object_types' in self.fields: + self.fields["related_object_types"].disabled = True + if 'related_object_type' in self.fields: self.fields["related_object_type"].disabled = True # Multi-object fields may not be set unique - if self.initial["type"] == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + if field_type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: self.fields["unique"].disabled = True # Add related_name to the Related Object fieldset for object/multiobject fields. @@ -203,26 +301,40 @@ def __init__(self, *args, **kwargs): else: del self.fields["related_name"] + def clean(self): + cleaned_data = super().clean() + field_type = cleaned_data.get("type") + is_polymorphic = cleaned_data.get("is_polymorphic", False) + + if field_type in ( + CustomFieldTypeChoices.TYPE_OBJECT, + CustomFieldTypeChoices.TYPE_MULTIOBJECT, + ) and is_polymorphic: + related_object_types = cleaned_data.get("related_object_types") + if not related_object_types: + self.add_error( + "related_object_types", + _("Polymorphic object fields must specify at least one related object type."), + ) + + return cleaned_data + def clean_primary(self): primary_fields = self.cleaned_data["custom_object_type"].fields.filter( primary=True ) if self.cleaned_data["primary"]: primary_fields.update(primary=False) - # It should be possible to have NO primary fields set on an object, and thus for its name to appear - # as the default of e.g. "Cat 1"; therefore don't try to guarantee that a primary is set - # else: - # if self.instance: - # other_primary_fields = primary_fields.exclude(pk=self.instance.id) - # else: - # other_primary_fields = primary_fields - # if not other_primary_fields.exists(): - # return True return self.cleaned_data["primary"] def save(self, commit=True): obj = super().save(commit=commit) - if obj.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT and obj.default: + # For polymorphic multiobject fields, skip default value propagation + if ( + not obj.is_polymorphic + and obj.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT + and obj.default + ): qs = obj.related_object_type.model_class().objects.filter( pk__in=obj.default ) diff --git a/netbox_custom_objects/migrations/0007_polymorphic_object_fields.py b/netbox_custom_objects/migrations/0007_polymorphic_object_fields.py new file mode 100644 index 00000000..abff30ca --- /dev/null +++ b/netbox_custom_objects/migrations/0007_polymorphic_object_fields.py @@ -0,0 +1,38 @@ +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("core", "0018_concrete_objecttype"), + ("netbox_custom_objects", "0006_customobjecttypefield_related_name_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="customobjecttypefield", + name="is_polymorphic", + field=models.BooleanField( + default=False, + verbose_name="polymorphic", + help_text=( + "When enabled, this field uses a generic foreign key and may reference " + "objects of multiple types. Set the allowed types in 'Related object types'." + ), + ), + ), + migrations.AddField( + model_name="customobjecttypefield", + name="related_object_types", + field=models.ManyToManyField( + blank=True, + related_name="polymorphic_custom_object_type_fields", + to="core.objecttype", + verbose_name="related object types", + help_text=( + "The types of objects this polymorphic field may reference " + "(used when 'Polymorphic' is enabled)." + ), + ), + ), + ] diff --git a/netbox_custom_objects/models.py b/netbox_custom_objects/models.py index aabb2927..c2009f4f 100644 --- a/netbox_custom_objects/models.py +++ b/netbox_custom_objects/models.py @@ -1,4 +1,5 @@ import decimal +import logging import re import threading from datetime import date, datetime @@ -16,7 +17,7 @@ from django.db import connection, IntegrityError, models, transaction from django.db.models import Q from django.db.models.functions import Lower -from django.db.models.signals import pre_delete, post_save +from django.db.models.signals import m2m_changed, pre_delete, post_save from django.dispatch import receiver from django.urls import reverse from django.utils.translation import gettext_lazy as _ @@ -55,7 +56,9 @@ from netbox_custom_objects.constants import APP_LABEL, RESERVED_FIELD_NAMES from netbox_custom_objects.field_types import FIELD_TYPE_CLASS from netbox_custom_objects.jobs import ReindexCustomObjectTypeJob -from netbox_custom_objects.utilities import _suppress_clear_cache, generate_model +from netbox_custom_objects.utilities import _suppress_clear_cache, extract_cot_id_from_model_name, generate_model + +logger = logging.getLogger(__name__) class UniquenessConstraintTestError(Exception): @@ -64,6 +67,11 @@ class UniquenessConstraintTestError(Exception): pass +def _table_exists(table_name): + """Return True if *table_name* exists in the current database.""" + return table_name in connection.introspection.table_names() + + USER_TABLE_DATABASE_NAME_PREFIX = "custom_objects_" @@ -378,9 +386,13 @@ def _fetch_and_generate_field_attrs( field_type = FIELD_TYPE_CLASS[field.type]() field_name = field.name - field_attrs[field.name] = field_type.get_model_field( - field, - ) + model_field = field_type.get_model_field(field) + + if isinstance(model_field, dict): + # Polymorphic Object field: dict of {attr_name: field_or_descriptor} + field_attrs.update(model_field) + else: + field_attrs[field.name] = model_field # Add to field objects only if the field was successfully generated field_attrs["_field_objects"][field.id] = { @@ -410,6 +422,7 @@ def _after_model_generation(self, attrs, model): for field_object in all_field_objects.values(): field_name = field_object["name"] + field_instance = field_object["field"] # Skip fields that were skipped due to recursion if field_name in skipped_fields: @@ -418,25 +431,59 @@ def _after_model_generation(self, attrs, model): # Only process fields that actually exist on the model # Fields might be skipped due to recursion prevention if hasattr(model._meta, 'get_field'): + # For polymorphic Object fields, check the content_type sub-field instead + if field_instance.is_polymorphic: + if field_instance.type == CustomFieldTypeChoices.TYPE_OBJECT: + # Collect through models is not applicable; GFK has no through + # The GFK itself needs no resolution + pass + elif field_instance.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + # Ensure the polymorphic through model is in the app registry. + # On server restart the registry is cleared; re-register if needed. + _apps = model._meta.apps + try: + through_model = _apps.get_model(APP_LABEL, field_instance.through_model_name) + # Always update source FK to point to the current model class. + # get_model() may be called multiple times (e.g. cache invalidation + # after a field save changes cache_timestamp). Without this update + # the through model's source FK would keep pointing at the old class, + # causing Django's Collector to raise ValueError during cascade delete: + # "Cannot query 'X': Must be 'TableYModel' instance." + source_field = through_model._meta.get_field("source") + source_field.remote_field.model = model + source_field.related_model = model + except LookupError: + field_type_obj = FIELD_TYPE_CLASS[CustomFieldTypeChoices.TYPE_MULTIOBJECT]() + source_model_str = f"{APP_LABEL}.{model.__name__}" + through_model = field_type_obj.get_polymorphic_through_model( + field_instance, source_model_str + ) + source_field = through_model._meta.get_field("source") + source_field.remote_field.model = model + source_field.related_model = model + _apps.register_model(APP_LABEL, through_model) + if through_model and through_model not in through_models: + through_models.append(through_model) + continue + try: field = model._meta.get_field(field_name) - # Field exists, process it - field_object["type"].after_model_generation( - field_object["field"], model, field_name - ) + except FieldDoesNotExist: + # Field skipped during generation (e.g. due to recursion guard). + continue - # Collect through models from M2M fields - if hasattr(field, 'remote_field') and hasattr(field.remote_field, 'through'): - through_model = field.remote_field.through - # Only collect custom through models, not auto-created Django ones - if (through_model and through_model not in through_models and - hasattr(through_model._meta, 'app_label') and - through_model._meta.app_label == APP_LABEL): - through_models.append(through_model) + field_object["type"].after_model_generation( + field_object["field"], model, field_name + ) - except Exception: - # Field doesn't exist (likely skipped due to recursion), skip processing - continue + # Collect through models from M2M fields + if hasattr(field, 'remote_field') and hasattr(field.remote_field, 'through'): + through_model = field.remote_field.through + # Only collect custom through models, not auto-created Django ones + if (through_model and through_model not in through_models and + hasattr(through_model._meta, 'app_label') and + through_model._meta.app_label == APP_LABEL): + through_models.append(through_model) # Store through models on the model for yielding in get_models() model._through_models = through_models @@ -644,6 +691,12 @@ def _ensure_field_fk_constraint(self, model, field_name): try: model_field = model._meta.get_field(field_name) except Exception: + logger.warning( + "_ensure_field_fk_constraint: could not get field %r on model %r; " + "FK constraint will not be created.", + field_name, model.__name__, + exc_info=True, + ) return if not (hasattr(model_field, 'remote_field') and model_field.remote_field): @@ -729,10 +782,18 @@ def delete(self, *args, **kwargs): model = self.get_model() - # Delete all CustomObjectTypeFields that reference this CustomObjectType + # Delete all CustomObjectTypeFields that reference this CustomObjectType (non-polymorphic) for field in CustomObjectTypeField.objects.filter(related_object_type=self.object_type): field.delete() + # Handle polymorphic fields that include this CustomObjectType among their allowed types + for field in CustomObjectTypeField.objects.filter( + is_polymorphic=True, related_object_types=self.object_type + ): + field.related_object_types.remove(self.object_type) + if not field.related_object_types.exists(): + field.delete() + object_type = ObjectType.objects.get_for_model(model) ObjectChange.objects.filter(changed_object_type=object_type).delete() super().delete(*args, **kwargs) @@ -742,6 +803,11 @@ def delete(self, *args, **kwargs): pre_delete.disconnect(handle_deleted_object) object_type.delete() with connection.schema_editor() as schema_editor: + # Drop polymorphic through tables first (they have FKs to django_content_type + # and to the main table, so they must be dropped before the main table). + for through_model in getattr(model, '_through_models', []): + if _table_exists(through_model._meta.db_table): + schema_editor.delete_model(through_model) schema_editor.delete_model(model) # Unregister the model and its through-models from Django's app registry so @@ -814,7 +880,22 @@ class CustomObjectTypeField(CloningMixin, ExportTemplatesMixin, ChangeLoggedMode on_delete=models.PROTECT, blank=True, null=True, - help_text=_("The type of NetBox object this field maps to (for object fields)"), + help_text=_("The type of NetBox object this field maps to (for non-polymorphic object fields)"), + ) + is_polymorphic = models.BooleanField( + default=False, + verbose_name=_("polymorphic"), + help_text=_( + "When enabled, this field uses a generic foreign key and may reference objects of multiple types. " + "Set the allowed types in 'Related object types'." + ), + ) + related_object_types = models.ManyToManyField( + to="core.ObjectType", + blank=True, + related_name="polymorphic_custom_object_type_fields", + verbose_name=_("related object types"), + help_text=_("The types of objects this polymorphic field may reference (used when 'Polymorphic' is enabled)."), ) name = models.CharField( verbose_name=_("name"), @@ -1010,12 +1091,15 @@ def __init__(self, *args, **kwargs): self._original_name = self.name self._original_type = self.type self._original_related_object_type_id = self.related_object_type_id + self._original_is_polymorphic = self.__dict__.get("is_polymorphic", False) def __str__(self): return self.label or self.name.replace("_", " ").capitalize() @property def model_class(self): + if self.is_polymorphic: + raise ValueError("Polymorphic fields reference multiple model classes; use related_object_types instead.") return apps.get_model( self.related_object_type.app_label, self.related_object_type.model ) @@ -1053,10 +1137,23 @@ def choices(self): @property def related_object_type_label(self): + if self.is_polymorphic: + labels = [] + for ot in self.related_object_types.all(): + if ot.app_label == APP_LABEL: + cot_id = extract_cot_id_from_model_name(ot.model) + if cot_id is not None: + try: + labels.append(CustomObjectType.get_content_type_label(cot_id)) + continue + except CustomObjectType.DoesNotExist: + pass + labels.append(object_type_name(ot, include_app=True)) + return ", ".join(labels) if labels else "—" + if not self.related_object_type: + return "—" if self.related_object_type.app_label == APP_LABEL: - custom_object_type_id = self.related_object_type.model.replace( - "table", "" - ).replace("model", "") + custom_object_type_id = extract_cot_id_from_model_name(self.related_object_type.model) return CustomObjectType.get_content_type_label(custom_object_type_id) return object_type_name(self.related_object_type, include_app=True) @@ -1198,14 +1295,27 @@ def clean(self): CustomFieldTypeChoices.TYPE_OBJECT, CustomFieldTypeChoices.TYPE_MULTIOBJECT, ): - if not self.related_object_type: - raise ValidationError( - { - "related_object_type": _( - "Object fields must define an object type." - ) - } - ) + if self.is_polymorphic: + # For polymorphic fields, related_object_type must be null + if self.related_object_type: + raise ValidationError( + { + "related_object_type": _( + "Polymorphic object fields must not define a single object type; " + "use 'Related object types' instead." + ) + } + ) + # related_object_types validation happens in forms (M2M set after save) + else: + if not self.related_object_type: + raise ValidationError( + { + "related_object_type": _( + "Object fields must define an object type." + ) + } + ) elif self.related_object_type: raise ValidationError( { @@ -1214,6 +1324,14 @@ def clean(self): ) } ) + elif self.is_polymorphic: + raise ValidationError( + { + "is_polymorphic": _( + "Only Object and Multi-Object fields may be polymorphic." + ) + } + ) # Related object filter can be set only for object-type fields, and must contain a dictionary mapping (if set) if self.related_object_filter is not None: @@ -1237,6 +1355,36 @@ def clean(self): } ) + # Prevent flipping is_polymorphic on an existing field. The DB schema + # (concrete GFK columns or through table) was created for the original value; + # changing it would leave the schema in an inconsistent state. + if self.pk and bool(self.is_polymorphic) != bool(self._original_is_polymorphic): + raise ValidationError( + {"is_polymorphic": _("Cannot change the polymorphic flag after field creation.")} + ) + + # Prevent renaming a polymorphic field. + # + # For a polymorphic GFK field the concrete DB columns are named + # "{name}_content_type" and "{name}_object_id"; for a polymorphic + # MultiObject field the through table is named + # "custom_objects_{cot_id}_{name}". The save() path currently has no + # logic to rename these artefacts (it falls through to `pass`), so + # allowing a rename would silently leave the DB schema out of sync with + # the field name stored in the row — causing query failures at runtime. + # + # Until explicit rename logic is implemented (renaming the GFK columns + # and/or the through table analogously to the non-polymorphic rename path + # at save() line ~1749), we reject renames outright. + if ( + self.pk + and (self.is_polymorphic or self._original_is_polymorphic) + and self.name != self._original_name + ): + raise ValidationError( + {"name": _("Cannot rename a polymorphic field after creation.")} + ) + # related_name can only be set for object-type fields if self.related_name and self.type not in ( CustomFieldTypeChoices.TYPE_OBJECT, @@ -1250,6 +1398,18 @@ def clean(self): } ) + # related_name is not supported on polymorphic fields: GenericForeignKey ignores it + # and PolymorphicM2MDescriptor never consumes it, so any value set here would be silently + # dropped with no working reverse accessor. + if self.related_name and self.is_polymorphic: + raise ValidationError( + { + "related_name": _( + "Reverse relation names are not supported for polymorphic fields." + ) + } + ) + # related_name must be unique per related_object_type (when set) if self.related_name and self.related_object_type_id: conflict = CustomObjectTypeField.objects.filter( @@ -1270,8 +1430,10 @@ def clean(self): } ) - # Check for recursion in object and multiobject fields - if (self.type in ( + # Check for recursion in object and multiobject fields (non-polymorphic only). + # Polymorphic fields' allowed types are a M2M set after save(), so their recursion + # check runs in the check_polymorphic_recursion m2m_changed signal handler instead. + if (not self.is_polymorphic and self.type in ( CustomFieldTypeChoices.TYPE_OBJECT, CustomFieldTypeChoices.TYPE_MULTIOBJECT, ) and self.related_object_type_id and @@ -1323,7 +1485,21 @@ def _has_circular_reference(self, custom_object_type, visited): # Add this type to visited set visited.add(custom_object_type.id) - # Check all object and multiobject fields in this custom object type + # Check all *non-polymorphic* object and multiobject fields in this COT. + # + # KNOWN LIMITATION: polymorphic fields (is_polymorphic=True) store their + # allowed target types on the related_object_types M2M, not on the + # related_object_type FK. This DFS therefore does not traverse edges + # introduced by polymorphic fields. A cycle that passes entirely through + # polymorphic legs (e.g. A →(poly) B →(poly) A) will go undetected. + # + # Fixing this requires also iterating field.related_object_types.filter( + # app_label=APP_LABEL) and recursing into each. The check_polymorphic_recursion + # signal already guards the direct A→B assignment, but cannot see multi-hop + # cycles that depend on polymorphic fields already on intermediate types. + # + # TODO: extend this DFS to also traverse polymorphic related_object_types + # so that multi-hop polymorphic cycles are detected at assignment time. related_objects_checked = set() for field in custom_object_type.fields.filter( type__in=[ @@ -1363,8 +1539,17 @@ def serialize(self, value): ): return value.isoformat() if self.type == CustomFieldTypeChoices.TYPE_OBJECT: + if self.is_polymorphic and value is not None: + ct = ContentType.objects.get_for_model(value) + return {"content_type_id": ct.pk, "object_id": value.pk} return value.pk if self.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + if self.is_polymorphic: + result = [] + for obj in value: + ct = ContentType.objects.get_for_model(obj) + result.append({"content_type_id": ct.pk, "object_id": obj.pk}) + return result or None return [obj.pk for obj in value] or None return value @@ -1385,9 +1570,34 @@ def deserialize(self, value): except ValueError: return value if self.type == CustomFieldTypeChoices.TYPE_OBJECT: + if self.is_polymorphic and isinstance(value, dict): + try: + ct = ContentType.objects.get(pk=value["content_type_id"]) + model = ct.model_class() + return model.objects.filter(pk=value["object_id"]).first() if model else None + except (ContentType.DoesNotExist, KeyError): + return None + if not self.related_object_type: + return None model = self.related_object_type.model_class() return model.objects.filter(pk=value).first() if self.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + if self.is_polymorphic and isinstance(value, list): + results = [] + for item in value: + if isinstance(item, dict): + try: + ct = ContentType.objects.get(pk=item["content_type_id"]) + model = ct.model_class() + if model: + obj = model.objects.filter(pk=item["object_id"]).first() + if obj: + results.append(obj) + except (ContentType.DoesNotExist, KeyError): + pass + return results + if not self.related_object_type: + return [] model = self.related_object_type.model_class() return model.objects.filter(pk__in=value) return value @@ -1578,7 +1788,18 @@ def validate(self, value): # Validate selected object elif self.type == CustomFieldTypeChoices.TYPE_OBJECT: - if type(value) is not int: + if self.is_polymorphic: + # Polymorphic value is {"content_type_id": int, "object_id": int} + if not isinstance(value, dict) or not isinstance( + value.get("content_type_id"), int + ) or not isinstance(value.get("object_id"), int): + raise ValidationError( + _( + "Polymorphic object value must be a dict with integer " + "content_type_id and object_id keys, not {type}." + ).format(type=type(value).__name__) + ) + elif type(value) is not int: raise ValidationError( _("Value must be an object ID, not {type}").format( type=type(value).__name__ @@ -1594,7 +1815,18 @@ def validate(self, value): ) ) for id in value: - if type(id) is not int: + if self.is_polymorphic: + # Each polymorphic entry is {"content_type_id": int, "object_id": int} + if not isinstance(id, dict) or not isinstance( + id.get("content_type_id"), int + ) or not isinstance(id.get("object_id"), int): + raise ValidationError( + _( + "Each polymorphic multiobject value must be a dict with " + "integer content_type_id and object_id keys." + ) + ) + elif type(id) is not int: raise ValidationError( _("Found invalid object ID: {id}").format(id=id) ) @@ -1619,126 +1851,165 @@ def original(self): @property def through_table_name(self): + # STABILITY CONTRACT — do not change this formula without a data migration. + # + # The table name is computed from (custom_object_type_id, name) and is + # never stored in the database. It is used as the physical PostgreSQL + # table name for polymorphic M2M through tables and as part of the + # in-memory Django model name returned by through_model_name. + # + # Consequences of changing the formula: + # • Existing through tables in live databases would be orphaned (the + # new name would not match any table on disk). + # • Any serialised reference to the through model (e.g. in cached app + # state or migration history) would become unresolvable. + # + # If the formula must change, write a data migration that renames every + # affected table with ALTER TABLE … RENAME TO before deploying the new + # code, and update through_model_name to match. return f"custom_objects_{self.custom_object_type_id}_{self.name}" @property def through_model_name(self): + # Derived directly from through_table_name; see its stability contract above. + # The "Through_" prefix ensures the in-memory model name is unique within + # the app registry and does not collide with user-visible model names. return f"Through_{self.through_table_name}" def save(self, *args, **kwargs): is_new = self._state.adding field_type = FIELD_TYPE_CLASS[self.type]() - model_field = field_type.get_model_field(self) model = self.custom_object_type.get_model() - model_field.contribute_to_class(model, self.name) with connection.schema_editor() as schema_editor: if self._state.adding: - schema_editor.add_field(model, model_field) - if self.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: - field_type.create_m2m_table(self, model, self.name) + if self.is_polymorphic: + # Polymorphic Object: add content_type + object_id columns + index + # Polymorphic MultiObject: create through table with content_type + object_id + if self.type == CustomFieldTypeChoices.TYPE_OBJECT: + field_type.add_polymorphic_object_columns(self, model, schema_editor) + elif self.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + field_type.create_polymorphic_m2m_table(self, model) + else: + model_field = field_type.get_model_field(self) + model_field.contribute_to_class(model, self.name) + schema_editor.add_field(model, model_field) + if self.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + field_type.create_m2m_table(self, model, self.name) else: - old_field = field_type.get_model_field(self.original) - old_field.contribute_to_class(model, self._original_name) - - # Special handling for MultiObject fields when the name changes - if ( - self.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT - and self.name != self._original_name - ): - # For renamed MultiObject fields, we just need to rename the through table - old_through_table_name = self.original.through_table_name - new_through_table_name = self.through_table_name - - # Check if old through table exists - with connection.cursor() as cursor: - tables = connection.introspection.table_names(cursor) - old_table_exists = old_through_table_name in tables - - if old_table_exists: - # Create temporary models to represent the old and new through table states - old_through_meta = type( - "Meta", - (), - { - "db_table": old_through_table_name, - "app_label": APP_LABEL, - "managed": True, - }, - ) - old_through_model = generate_model( - f"TempOld{self.original.through_model_name}", - (models.Model,), - { - "__module__": "netbox_custom_objects.models", - "Meta": old_through_meta, - "id": models.AutoField(primary_key=True), - "source": models.ForeignKey( - model, - on_delete=models.CASCADE, - db_column="source_id", - related_name="+", - ), - "target": models.ForeignKey( - model, - on_delete=models.CASCADE, - db_column="target_id", - related_name="+", - ), - }, + # Polymorphic fields: renames and type changes are rejected by clean(). + # Non-schema attributes (label, description, …) may still change here. + # If clean() was bypassed and a rename slipped through, raise rather + # than silently leaving DB columns / through table out of sync. + if self.is_polymorphic or self._original_is_polymorphic: + if self.name != self._original_name: + raise ValidationError( + {"name": _("Cannot rename a polymorphic field after creation.")} ) + else: + old_field = field_type.get_model_field(self.original) + old_field.contribute_to_class(model, self._original_name) + + # Special handling for MultiObject fields when the name changes + if ( + self.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT + and self.name != self._original_name + ): + # For renamed MultiObject fields, we just need to rename the through table + old_through_table_name = self.original.through_table_name + new_through_table_name = self.through_table_name + + # Check if old through table exists + with connection.cursor() as cursor: + tables = connection.introspection.table_names(cursor) + old_table_exists = old_through_table_name in tables + + if old_table_exists: + # Create temporary models to represent the old and new through table states + old_through_meta = type( + "Meta", + (), + { + "db_table": old_through_table_name, + "app_label": APP_LABEL, + "managed": True, + }, + ) + _old_through_model = generate_model( + f"TempOld{self.original.through_model_name}", + (models.Model,), + { + "__module__": "netbox_custom_objects.models", + "Meta": old_through_meta, + "id": models.AutoField(primary_key=True), + "source": models.ForeignKey( + model, + on_delete=models.CASCADE, + db_column="source_id", + related_name="+", + ), + "target": models.ForeignKey( + model, + on_delete=models.CASCADE, + db_column="target_id", + related_name="+", + ), + }, + ) - new_through_meta = type( - "Meta", - (), - { - "db_table": new_through_table_name, - "app_label": APP_LABEL, - "managed": True, - }, - ) - new_through_model = generate_model( - f"TempNew{self.through_model_name}", - (models.Model,), - { - "__module__": "netbox_custom_objects.models", - "Meta": new_through_meta, - "id": models.AutoField(primary_key=True), - "source": models.ForeignKey( - model, - on_delete=models.CASCADE, - db_column="source_id", - related_name="+", - ), - "target": models.ForeignKey( - model, - on_delete=models.CASCADE, - db_column="target_id", - related_name="+", - ), - }, - ) - new_through_model # To silence ruff error + new_through_meta = type( + "Meta", + (), + { + "db_table": new_through_table_name, + "app_label": APP_LABEL, + "managed": True, + }, + ) + new_through_model = generate_model( + f"TempNew{self.through_model_name}", + (models.Model,), + { + "__module__": "netbox_custom_objects.models", + "Meta": new_through_meta, + "id": models.AutoField(primary_key=True), + "source": models.ForeignKey( + model, + on_delete=models.CASCADE, + db_column="source_id", + related_name="+", + ), + "target": models.ForeignKey( + model, + on_delete=models.CASCADE, + db_column="target_id", + related_name="+", + ), + }, + ) + # Rename the table using Django's schema editor. + # new_through_model is passed as the first argument so Django + # can rename associated sequences (e.g. on PostgreSQL). + schema_editor.alter_db_table( + new_through_model, + old_through_table_name, + new_through_table_name, + ) + else: + # No old table exists, create the new through table + field_type.create_m2m_table(self, model, self.name) - # Rename the table using Django's schema editor - schema_editor.alter_db_table( - old_through_model, - old_through_table_name, - new_through_table_name, - ) + # Alter the field normally (this updates the field definition) + schema_editor.alter_field(model, old_field, model_field) else: - # No old table exists, create the new through table - field_type.create_m2m_table(self, model, self.name) - - # Alter the field normally (this updates the field definition) - schema_editor.alter_field(model, old_field, model_field) - else: - # Normal field alteration - schema_editor.alter_field(model, old_field, model_field) + # Normal field alteration + model_field = field_type.get_model_field(self) + model_field.contribute_to_class(model, self.name) + schema_editor.alter_field(model, old_field, model_field) - # Ensure FK constraints are properly created for OBJECT fields with CASCADE behavior + # Ensure FK constraints are properly created for non-polymorphic OBJECT fields should_ensure_fk = False - if self.type == CustomFieldTypeChoices.TYPE_OBJECT: + if self.type == CustomFieldTypeChoices.TYPE_OBJECT and not self.is_polymorphic: if self._state.adding: should_ensure_fk = True else: @@ -1782,16 +2053,23 @@ def ensure_constraint(): def delete(self, *args, **kwargs): field_type = FIELD_TYPE_CLASS[self.type]() - model_field = field_type.get_model_field(self) model = self.custom_object_type.get_model() - model_field.contribute_to_class(model, self.name) with connection.schema_editor() as schema_editor: - if self.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: - apps = model._meta.apps - through_model = apps.get_model(APP_LABEL, self.through_model_name) - schema_editor.delete_model(through_model) - schema_editor.remove_field(model, model_field) + if self.is_polymorphic: + if self.type == CustomFieldTypeChoices.TYPE_OBJECT: + field_type.remove_polymorphic_object_columns(self, model, schema_editor) + elif self.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + field_type.drop_polymorphic_m2m_table(self, model, schema_editor) + else: + model_field = field_type.get_model_field(self) + model_field.contribute_to_class(model, self.name) + + if self.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + _apps = model._meta.apps + through_model = _apps.get_model(APP_LABEL, self.through_model_name) + schema_editor.delete_model(through_model) + schema_editor.remove_field(model, model_field) # Clear the model cache for this CustomObjectType when a field is deleted self.custom_object_type.clear_model_cache(self.custom_object_type.id) @@ -1851,6 +2129,38 @@ def clear_cache_on_custom_object_type_save(sender, instance, **kwargs): CustomObjectType.clear_model_cache(instance.id) +@receiver(m2m_changed, sender=CustomObjectTypeField.related_object_types.through) +def check_polymorphic_recursion(sender, instance, action, pk_set, **kwargs): + """ + Prevent circular references in polymorphic field allowed-type lists. + + clean() cannot check this because related_object_types is a M2M that is set + after the instance is saved. m2m_changed fires on pre_add, which lets us abort + the operation before any rows are written. + """ + if action != "pre_add" or not pk_set: + return + + own_object_type_id = instance.custom_object_type.object_type_id + + for ot_pk in pk_set: + if ot_pk == own_object_type_id: + # Self-reference is permitted (same pattern as non-polymorphic check). + continue + try: + related_cot = CustomObjectType.objects.get(object_type_id=ot_pk) + except CustomObjectType.DoesNotExist: + continue # Native NetBox type — no COT dependency chain to traverse. + visited = {instance.custom_object_type_id} + if instance._has_circular_reference(related_cot, visited): + raise ValidationError( + _( + "Circular reference detected: one of the selected object types would " + "create a circular dependency between custom object types." + ) + ) + + @receiver(post_save, sender=CustomObjectTypeField) def clear_cache_on_field_save(sender, instance, **kwargs): """ @@ -1859,10 +2169,17 @@ def clear_cache_on_field_save(sender, instance, **kwargs): """ if instance.custom_object_type_id: CustomObjectType.clear_model_cache(instance.custom_object_type_id) + # Clear caches for non-polymorphic fields pointing to this custom object type for pointing_field in CustomObjectTypeField.objects.filter( related_object_type=instance.custom_object_type.object_type ): CustomObjectType.clear_model_cache(pointing_field.custom_object_type_id) + # Clear caches for polymorphic fields that include this custom object type + for pointing_field in CustomObjectTypeField.objects.filter( + is_polymorphic=True, + related_object_types=instance.custom_object_type.object_type, + ): + CustomObjectType.clear_model_cache(pointing_field.custom_object_type_id) @receiver(pre_delete, sender=CustomObjectTypeField) diff --git a/netbox_custom_objects/templatetags/custom_object_utils.py b/netbox_custom_objects/templatetags/custom_object_utils.py index 73042aec..b2ea681f 100644 --- a/netbox_custom_objects/templatetags/custom_object_utils.py +++ b/netbox_custom_objects/templatetags/custom_object_utils.py @@ -46,4 +46,4 @@ def get_field_is_ui_visible(obj, field: CustomObjectTypeField) -> bool: @register.filter(name="get_child_relations") def get_child_relations(obj, field: CustomObjectTypeField): - return getattr(obj, field.name).all() + return getattr(obj, field.name) diff --git a/netbox_custom_objects/tests/base.py b/netbox_custom_objects/tests/base.py index ae48fc57..a3f9db64 100644 --- a/netbox_custom_objects/tests/base.py +++ b/netbox_custom_objects/tests/base.py @@ -1,4 +1,5 @@ # Test utilities for netbox_custom_objects plugin +from django.db import connection from django.test import Client from core.models import ObjectType from extras.models import CustomFieldChoiceSet @@ -6,14 +7,108 @@ from netbox_custom_objects.models import CustomObjectType, CustomObjectTypeField +_DYNAMIC_TABLE_PREFIX = "custom_objects_" + + +def _drop_dynamic_tables(): + """Drop leftover dynamic custom-object tables and purge stale app-registry state. + + Two problems arise from --keepdb runs: + + 1. DB tables — Django's ``flush`` command uses ``django_table_names()`` which + only returns ORM-registered tables. Dynamic tables created by this plugin + live outside that registry, so ``flush`` doesn't TRUNCATE them — but they + DO have foreign keys to ``django_content_type``, causing PostgreSQL to + reject the TRUNCATE with "cannot truncate a table referenced in a foreign + key constraint". Dropping these orphan tables first fixes that. + + 2. App-registry models — stale dynamic models from prior runs may still be + registered in Django's in-process app registry even after their DB tables + are dropped. Django's deletion collector walks ``Site._meta.related_objects`` + (and similar) and queries every registered model that has a FK to the object + being deleted. If a stale model points to a now-dropped table the query + raises ``ProgrammingError``. We must deregister those models AND delete the + corresponding CustomObjectType rows from the DB before calling + ``apps.clear_cache()`` so that the next ``get_models()`` invocation rebuilds + ``_meta.related_objects`` without phantom FK references. + """ + from django.apps import apps as django_apps + from netbox_custom_objects.constants import APP_LABEL + from netbox_custom_objects.models import CustomObjectType + + # Step 1 — clear the plugin's own model cache so get_model() doesn't hand out + # stale model objects that still reference non-existent through tables. + CustomObjectType.clear_model_cache() + + # Step 2 — remove stale dynamic models from apps.all_models. + # We deliberately do NOT call apps.clear_cache() here: clear_cache() triggers + # get_models() on each AppConfig, and our override in __init__.py calls + # get_model() for every row in CustomObjectType.objects.all(). Any stale COT + # rows still in the DB would be immediately re-registered, undoing this cleanup. + app_models = django_apps.all_models.get(APP_LABEL, {}) + stale_names = [ + name for name, model in list(app_models.items()) + if hasattr(model, '_meta') and model._meta.db_table.startswith(_DYNAMIC_TABLE_PREFIX) + ] + for name in stale_names: + del app_models[name] + + # Step 3 — delete stale CustomObjectType rows via queryset (direct SQL DELETE, + # not the custom cot.delete() method which tries schema operations on tables + # that no longer exist). Wrapped in a broad except so a partially-migrated + # schema never blocks test startup. + try: + CustomObjectType.objects.all().delete() + except Exception: + pass + + # Step 4 — drop all dynamic DB tables. + all_tables = connection.introspection.table_names() + dynamic = [t for t in all_tables if t.startswith(_DYNAMIC_TABLE_PREFIX)] + if dynamic: + with connection.cursor() as cursor: + for table in dynamic: + cursor.execute(f'DROP TABLE IF EXISTS "{table}" CASCADE') + + # Step 5 — rebuild the app registry cache now that both the stale model + # entries (step 2) and the stale COT rows (step 3) are gone. get_models() + # finds no CustomObjectType rows so nothing is re-registered, and + # Site._meta.related_objects (etc.) is rebuilt without phantom FK pointers. + if stale_names: + django_apps.clear_cache() + class TransactionCleanupMixin: """Mixin for TransactionTestCase subclasses that create CustomObjectType instances. Deletes all COTs in tearDown so their backing tables are dropped before the - database flush that TransactionTestCase performs between tests. + database flush that TransactionTestCase performs between tests. Also drops + any leftover dynamic tables before the flush so a dirty database from a + previous (failed) run cannot block the TRUNCATE. + + Django 5.2 notes: + - _pre_setup / _fixture_setup are classmethods called once per class. + - _fixture_teardown is an instance method called after *every* test. + - Overriding _fixture_teardown is the correct place to drop dynamic tables + because it always runs, even when setUp raised an exception. + - Overriding _pre_setup (as classmethod) handles leftover tables from a + previous run before the first test of the current run. """ + @classmethod + def _pre_setup(cls): + # Drop leftovers from any previous (possibly failed) run first, so the + # normal fixture setup that follows isn't blocked by orphan tables. + _drop_dynamic_tables() + super()._pre_setup() + + def _fixture_teardown(self): + # Drop dynamic tables before Django's flush; without this, the flush + # command's TRUNCATE of django_content_type fails because our through + # tables have FK references to it. + _drop_dynamic_tables() + super()._fixture_teardown() + def tearDown(self): for cot in CustomObjectType.objects.all(): try: diff --git a/netbox_custom_objects/tests/test_polymorphic_fields.py b/netbox_custom_objects/tests/test_polymorphic_fields.py new file mode 100644 index 00000000..81a9e0e7 --- /dev/null +++ b/netbox_custom_objects/tests/test_polymorphic_fields.py @@ -0,0 +1,1107 @@ +""" +Tests for polymorphic GenericForeignKey field support (issue #31). + +Covers both API and UI (form) paths for: + - Polymorphic single-object (GFK) fields + - Polymorphic multi-object (through-table M2M) fields +""" +import json + +from django.test import TransactionTestCase +from django.urls import reverse +from rest_framework import status + +from core.models import ObjectType +from dcim.models import Site +from ipam.models import Prefix, IPAddress +from ipam.choices import PrefixStatusChoices +from users.models import ObjectPermission, Token + +from netbox_custom_objects.constants import APP_LABEL +from netbox_custom_objects.models import CustomObjectType, CustomObjectTypeField +from netbox_custom_objects.tests.base import CustomObjectsTestCase, TransactionCleanupMixin + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _create_token(user): + from users.choices import TokenVersionChoices + t = Token(version=TokenVersionChoices.V1, user=user) + t.save() + return t.token # plaintext for V1 tokens + + +def _grant_perm(user, action, model_class, name=None): + perm = ObjectPermission(name=name or f"poly-test-{action}", actions=[action]) + perm.save() + perm.users.add(user) + perm.object_types.add(ObjectType.objects.get_for_model(model_class)) + return perm + + +# --------------------------------------------------------------------------- +# API tests +# --------------------------------------------------------------------------- + +class PolymorphicFieldAPITest(TransactionCleanupMixin, CustomObjectsTestCase, TransactionTestCase): + """ + API tests for polymorphic Object and MultiObject fields. + Uses TransactionTestCase so that DB table creation/deletion is committed. + """ + + def setUp(self): + super().setUp() + from django.test import Client as DjangoClient + from utilities.testing import create_test_user + self.user = create_test_user("poly-api-user") + token_key = _create_token(self.user) + self.header = {"HTTP_AUTHORIZATION": f"Token {token_key}"} + # Reset client to clear the session cookie set by CustomObjectsTestCase.setUp() + # (force_login causes SessionAuthentication to take priority over TokenAuthentication) + self.client = DjangoClient() + + # Site and Prefix used as related objects + self.site = Site.objects.create(name="Test Site", slug="test-site") + self.prefix = Prefix.objects.create( + prefix="10.0.0.0/8", + status=PrefixStatusChoices.STATUS_ACTIVE, + ) + + self.site_ot = ObjectType.objects.get(app_label="dcim", model="site") + self.prefix_ot = ObjectType.objects.get(app_label="ipam", model="prefix") + + # COT with a primary text field + self.cot = CustomObjectType.objects.create( + name="PolyTest", slug="poly-test", + verbose_name_plural="Poly Tests", + ) + CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="name", label="Name", type="text", + primary=True, required=True, + ) + + # Polymorphic single-object (GFK) field + self.gfk_field = CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="poly_obj", label="Poly Obj", type="object", + is_polymorphic=True, + ) + self.gfk_field.related_object_types.set([self.site_ot, self.prefix_ot]) + + # Polymorphic multi-object field + self.m2m_field = CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="poly_multi", label="Poly Multi", type="multiobject", + is_polymorphic=True, + ) + self.m2m_field.related_object_types.set([self.site_ot, self.prefix_ot]) + + self.model = self.cot.get_model() + self.field_perm_ot = ObjectType.objects.get_for_model(CustomObjectTypeField) + + # --- Field creation via API --- + + def test_create_polymorphic_object_field_via_api(self): + """POSTing a new polymorphic Object field with related_object_types_input succeeds.""" + _grant_perm(self.user, "add", CustomObjectTypeField, "field-add") + url = reverse("plugins-api:netbox_custom_objects-api:customobjecttypefield-list") + data = { + "custom_object_type": self.cot.pk, + "name": "poly_obj2", + "label": "Poly Obj 2", + "type": "object", + "is_polymorphic": True, + "related_object_types_input": [ + {"app_label": "dcim", "model": "site"}, + ], + } + response = self.client.post(url, json.dumps(data), content_type="application/json", **self.header) + self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.data) + created = CustomObjectTypeField.objects.get(pk=response.data["id"]) + self.assertTrue(created.is_polymorphic) + self.assertEqual(created.related_object_types.count(), 1) + + def test_create_polymorphic_multiobject_field_via_api(self): + """POSTing a new polymorphic MultiObject field succeeds.""" + _grant_perm(self.user, "add", CustomObjectTypeField, "field-add") + url = reverse("plugins-api:netbox_custom_objects-api:customobjecttypefield-list") + data = { + "custom_object_type": self.cot.pk, + "name": "poly_multi2", + "type": "multiobject", + "is_polymorphic": True, + "related_object_types_input": [ + {"app_label": "dcim", "model": "site"}, + {"app_label": "ipam", "model": "prefix"}, + ], + } + response = self.client.post(url, json.dumps(data), content_type="application/json", **self.header) + self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.data) + created = CustomObjectTypeField.objects.get(pk=response.data["id"]) + self.assertTrue(created.is_polymorphic) + self.assertEqual(created.related_object_types.count(), 2) + + def test_polymorphic_field_requires_related_types(self): + """POSTing a polymorphic Object field without related_object_types_input returns 400.""" + _grant_perm(self.user, "add", CustomObjectTypeField, "field-add") + url = reverse("plugins-api:netbox_custom_objects-api:customobjecttypefield-list") + data = { + "custom_object_type": self.cot.pk, + "name": "bad_poly", + "type": "object", + "is_polymorphic": True, + } + response = self.client.post(url, json.dumps(data), content_type="application/json", **self.header) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_field_list_includes_is_polymorphic_and_related_types(self): + """GET on a polymorphic field returns is_polymorphic=True and related_object_types.""" + _grant_perm(self.user, "view", CustomObjectTypeField, "field-view") + url = reverse( + "plugins-api:netbox_custom_objects-api:customobjecttypefield-detail", + kwargs={"pk": self.gfk_field.pk}, + ) + response = self.client.get(url, **self.header) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue(response.data["is_polymorphic"]) + self.assertEqual(len(response.data["related_object_types"]), 2) + app_models = { + (r["app_label"], r["model"]) for r in response.data["related_object_types"] + } + self.assertIn(("dcim", "site"), app_models) + self.assertIn(("ipam", "prefix"), app_models) + + # --- Immutability: is_polymorphic and related types cannot change after creation --- + + def test_patch_is_polymorphic_false_on_existing_polymorphic_field_rejected(self): + """PATCH is_polymorphic=False on an existing polymorphic field returns 400.""" + _grant_perm(self.user, "change", CustomObjectTypeField, "field-change") + url = reverse( + "plugins-api:netbox_custom_objects-api:customobjecttypefield-detail", + kwargs={"pk": self.gfk_field.pk}, + ) + response = self.client.patch( + url, + json.dumps({"is_polymorphic": False}), + content_type="application/json", + **self.header, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST, response.data) + self.gfk_field.refresh_from_db() + self.assertTrue(self.gfk_field.is_polymorphic) + + def test_patch_related_object_types_input_on_existing_field_rejected(self): + """PATCH related_object_types_input on an existing polymorphic field returns 400.""" + _grant_perm(self.user, "change", CustomObjectTypeField, "field-change") + url = reverse( + "plugins-api:netbox_custom_objects-api:customobjecttypefield-detail", + kwargs={"pk": self.gfk_field.pk}, + ) + response = self.client.patch( + url, + json.dumps({"related_object_types_input": [{"app_label": "dcim", "model": "site"}]}), + content_type="application/json", + **self.header, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST, response.data) + + def test_patch_app_label_model_on_existing_non_polymorphic_field_rejected(self): + """PATCH app_label+model on an existing non-polymorphic object field returns 400.""" + _grant_perm(self.user, "change", CustomObjectTypeField, "field-change") + # Create a non-polymorphic object field + site_ot = ObjectType.objects.get(app_label="dcim", model="site") + non_poly_field = CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="single_obj", label="Single Obj", type="object", + is_polymorphic=False, + related_object_type=site_ot, + ) + url = reverse( + "plugins-api:netbox_custom_objects-api:customobjecttypefield-detail", + kwargs={"pk": non_poly_field.pk}, + ) + response = self.client.patch( + url, + json.dumps({"app_label": "ipam", "model": "prefix"}), + content_type="application/json", + **self.header, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST, response.data) + non_poly_field.refresh_from_db() + self.assertEqual(non_poly_field.related_object_type, site_ot) + + # --- Field name collision --- + + def test_create_polymorphic_field_with_duplicate_name_rejected(self): + """POST a polymorphic field whose name already exists on the same COT returns 400.""" + _grant_perm(self.user, "add", CustomObjectTypeField, "field-add-dup") + url = reverse("plugins-api:netbox_custom_objects-api:customobjecttypefield-list") + data = { + "custom_object_type": self.cot.pk, + # "poly_obj" already exists on self.cot (created in setUp) + "name": "poly_obj", + "label": "Duplicate Poly", + "type": "object", + "is_polymorphic": True, + "related_object_types_input": [ + {"app_label": "dcim", "model": "site"}, + ], + } + response = self.client.post(url, json.dumps(data), content_type="application/json", **self.header) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST, response.data) + + def test_rename_polymorphic_field_to_collide_with_existing_field_rejected(self): + """PATCH name of a polymorphic field to an already-taken name returns 400.""" + _grant_perm(self.user, "change", CustomObjectTypeField, "field-change-dup") + url = reverse( + "plugins-api:netbox_custom_objects-api:customobjecttypefield-detail", + kwargs={"pk": self.m2m_field.pk}, + ) + # Rename poly_multi → poly_obj which is already taken + response = self.client.patch( + url, + json.dumps({"name": "poly_obj"}), + content_type="application/json", + **self.header, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST, response.data) + self.m2m_field.refresh_from_db() + self.assertEqual(self.m2m_field.name, "poly_multi") + + # --- Custom object CRUD with polymorphic GFK --- + + def _obj_list_url(self): + return reverse( + "plugins-api:netbox_custom_objects-api:customobject-list", + kwargs={"custom_object_type": self.cot.slug}, + ) + + def _obj_detail_url(self, pk): + return reverse( + "plugins-api:netbox_custom_objects-api:customobject-detail", + kwargs={"custom_object_type": self.cot.slug, "pk": pk}, + ) + + def test_create_custom_object_with_polymorphic_gfk_via_api(self): + """POST a custom object with a polymorphic single-object value (Site).""" + _grant_perm(self.user, "add", self.model, "co-add") + from django.contrib.contenttypes.models import ContentType + site_ct = ContentType.objects.get_for_model(Site) + data = { + "name": "gfk-test-obj", + "poly_obj": {"content_type_id": site_ct.pk, "object_id": self.site.pk}, + } + response = self.client.post( + self._obj_list_url(), json.dumps(data), content_type="application/json", **self.header + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.content) + obj = self.model.objects.get(pk=json.loads(response.content)["id"]) + self.assertEqual(obj.poly_obj, self.site) + + def test_create_custom_object_with_polymorphic_gfk_as_prefix(self): + """POST a custom object with a polymorphic single-object value (Prefix).""" + _grant_perm(self.user, "add", self.model, "co-add") + from django.contrib.contenttypes.models import ContentType + prefix_ct = ContentType.objects.get_for_model(Prefix) + data = { + "name": "gfk-prefix-obj", + "poly_obj": {"content_type_id": prefix_ct.pk, "object_id": self.prefix.pk}, + } + response = self.client.post( + self._obj_list_url(), json.dumps(data), content_type="application/json", **self.header + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.content) + obj = self.model.objects.get(pk=json.loads(response.content)["id"]) + self.assertEqual(obj.poly_obj, self.prefix) + + def test_read_custom_object_gfk_representation(self): + """GET a custom object returns polymorphic GFK with _content_type annotation.""" + _grant_perm(self.user, "view", self.model, "co-view") + obj = self.model.objects.create(name="gfk-read-obj") + obj.poly_obj = self.site + obj.save() + + response = self.client.get(self._obj_detail_url(obj.pk), **self.header) + self.assertEqual(response.status_code, status.HTTP_200_OK) + poly_data = response.data["poly_obj"] + self.assertIsNotNone(poly_data) + self.assertEqual(poly_data["_content_type"], "dcim.site") + self.assertEqual(poly_data["id"], self.site.pk) + + def test_update_custom_object_clears_gfk(self): + """PATCH with poly_obj=null clears the GFK.""" + _grant_perm(self.user, "change", self.model, "co-change") + obj = self.model.objects.create(name="gfk-clear-obj") + obj.poly_obj = self.site + obj.save() + + response = self.client.patch( + self._obj_detail_url(obj.pk), + json.dumps({"poly_obj": None}), + content_type="application/json", + **self.header, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) + obj.refresh_from_db() + self.assertIsNone(obj.poly_obj) + + # --- Custom object CRUD with polymorphic M2M --- + + def test_create_custom_object_with_polymorphic_m2m_via_api(self): + """POST a custom object with a list of heterogeneous polymorphic M2M values.""" + _grant_perm(self.user, "add", self.model, "co-add") + from django.contrib.contenttypes.models import ContentType + site_ct = ContentType.objects.get_for_model(Site) + prefix_ct = ContentType.objects.get_for_model(Prefix) + data = { + "name": "m2m-test-obj", + "poly_multi": [ + {"content_type_id": site_ct.pk, "object_id": self.site.pk}, + {"content_type_id": prefix_ct.pk, "object_id": self.prefix.pk}, + ], + } + response = self.client.post( + self._obj_list_url(), json.dumps(data), content_type="application/json", **self.header + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.content) + obj = self.model.objects.get(pk=json.loads(response.content)["id"]) + members = obj.poly_multi.all() + self.assertIn(self.site, members) + self.assertIn(self.prefix, members) + + def test_read_custom_object_m2m_representation(self): + """GET returns poly_multi as a list of objects with _content_type.""" + _grant_perm(self.user, "view", self.model, "co-view") + obj = self.model.objects.create(name="m2m-read-obj") + obj.poly_multi.add(self.site, self.prefix) + + response = self.client.get(self._obj_detail_url(obj.pk), **self.header) + self.assertEqual(response.status_code, status.HTTP_200_OK) + poly_list = response.data["poly_multi"] + self.assertEqual(len(poly_list), 2) + content_types = {item["_content_type"] for item in poly_list} + self.assertIn("dcim.site", content_types) + self.assertIn("ipam.prefix", content_types) + + def test_update_custom_object_replaces_m2m(self): + """PATCH with a new poly_multi list replaces the existing values.""" + _grant_perm(self.user, "change", self.model, "co-change") + obj = self.model.objects.create(name="m2m-replace-obj") + obj.poly_multi.add(self.site) + + from django.contrib.contenttypes.models import ContentType + prefix_ct = ContentType.objects.get_for_model(Prefix) + response = self.client.patch( + self._obj_detail_url(obj.pk), + json.dumps({"poly_multi": [{"content_type_id": prefix_ct.pk, "object_id": self.prefix.pk}]}), + content_type="application/json", + **self.header, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) + obj.refresh_from_db() + members = obj.poly_multi.all() + self.assertNotIn(self.site, members) + self.assertIn(self.prefix, members) + + # --- Orphaned / unresolvable content type --- + + def test_create_custom_object_with_unresolvable_content_type_rejected(self): + """POST with a content_type_id whose model_class() is None returns 400.""" + _grant_perm(self.user, "add", self.model, "co-add") + + # ObjectType is a proxy for ContentType. An entry with a nonexistent app/model + # gives a row whose model_class() returns None. Use get_or_create so the test + # is idempotent when run with --keepdb. + orphan_ot, _ = ObjectType.objects.get_or_create( + app_label="nonexistent_app", model="nonexistentmodel" + ) + # Add to allowed types so we pass the allow-list check and reach model_class(). + self.gfk_field.related_object_types.add(orphan_ot) + + data = { + "name": "orphan-ct-obj", + "poly_obj": {"content_type_id": orphan_ot.pk, "object_id": 1}, + } + response = self.client.post( + self._obj_list_url(), + json.dumps(data), + content_type="application/json", + **self.header, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST, response.content) + # Should return the sanitized message, not internal CT details. + self.assertNotIn(b"nonexistentmodel", response.content) + self.assertNotIn(b"nonexistent_app", response.content) + + # Remove from M2M before tearDown drops the through table; leaving the + # stale django_content_type row itself is harmless. + self.gfk_field.related_object_types.remove(orphan_ot) + + # --- Content-type enforcement tests --- + + def test_create_custom_object_with_disallowed_gfk_type_rejected(self): + """POST with poly_obj set to a disallowed content type returns 400.""" + _grant_perm(self.user, "add", self.model, "co-add") + from django.contrib.contenttypes.models import ContentType + ip_address = IPAddress.objects.create(address="192.0.2.1/24") + disallowed_ct = ContentType.objects.get_for_model(IPAddress) + data = { + "name": "gfk-disallowed-obj", + "poly_obj": {"content_type_id": disallowed_ct.pk, "object_id": ip_address.pk}, + } + response = self.client.post( + self._obj_list_url(), + json.dumps(data), + content_type="application/json", + **self.header, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST, response.content) + self.assertIn(b"not allowed", response.content.lower()) + + def test_create_custom_object_with_disallowed_m2m_type_rejected(self): + """POST with poly_multi containing a disallowed content type returns 400.""" + _grant_perm(self.user, "add", self.model, "co-add") + from django.contrib.contenttypes.models import ContentType + ip_address = IPAddress.objects.create(address="192.0.2.2/24") + disallowed_ct = ContentType.objects.get_for_model(IPAddress) + site_ct = ContentType.objects.get_for_model(Site) + data = { + "name": "m2m-disallowed-obj", + "poly_multi": [ + {"content_type_id": site_ct.pk, "object_id": self.site.pk}, + {"content_type_id": disallowed_ct.pk, "object_id": ip_address.pk}, + ], + } + response = self.client.post( + self._obj_list_url(), + json.dumps(data), + content_type="application/json", + **self.header, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST, response.content) + self.assertIn(b"not allowed", response.content.lower()) + + # --- DELETE --- + + def test_delete_custom_object_with_gfk_value(self): + """DELETE a custom object with a populated GFK polymorphic field returns 204 and removes the object.""" + _grant_perm(self.user, "delete", self.model, "co-delete") + from django.contrib.contenttypes.models import ContentType + site_ct = ContentType.objects.get_for_model(Site) + obj = self.model.objects.create( + name="gfk-delete-obj", + poly_obj_content_type=site_ct, + poly_obj_object_id=self.site.pk, + ) + pk = obj.pk + + response = self.client.delete(self._obj_detail_url(pk), **self.header) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT, response.content) + self.assertFalse(self.model.objects.filter(pk=pk).exists()) + + def test_delete_custom_object_with_m2m_values(self): + """ + DELETE a custom object with populated M2M polymorphic values returns 204, removes the object, + and cleans up through-table rows. + """ + from django.apps import apps as django_apps + _grant_perm(self.user, "delete", self.model, "co-delete") + obj = self.model.objects.create(name="m2m-delete-obj") + obj.poly_multi.add(self.site, self.prefix) + pk = obj.pk + + # Resolve the through model before the delete so we can verify cascade cleanup. + through_model = django_apps.get_model(APP_LABEL, self.m2m_field.through_model_name) + + response = self.client.delete(self._obj_detail_url(pk), **self.header) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT, response.content) + self.assertFalse(self.model.objects.filter(pk=pk).exists()) + # Through-table rows for this object should be gone. + self.assertFalse(through_model.objects.filter(source_id=pk).exists()) + + def test_delete_custom_object_with_empty_polymorphic_fields(self): + """DELETE a custom object with no polymorphic values set returns 204.""" + _grant_perm(self.user, "delete", self.model, "co-delete") + obj = self.model.objects.create(name="empty-poly-delete-obj") + pk = obj.pk + + response = self.client.delete(self._obj_detail_url(pk), **self.header) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT, response.content) + self.assertFalse(self.model.objects.filter(pk=pk).exists()) + + +# --------------------------------------------------------------------------- +# UI / form tests +# --------------------------------------------------------------------------- + +class PolymorphicFieldUITest(TransactionCleanupMixin, CustomObjectsTestCase, TransactionTestCase): + """ + UI tests for polymorphic fields on the custom object edit form. + """ + + def setUp(self): + super().setUp() + from utilities.testing import create_test_user + self.user = create_test_user("poly-ui-user") + self.client.force_login(self.user) + + self.site1 = Site.objects.create(name="UI Site 1", slug="ui-site-1") + self.site2 = Site.objects.create(name="UI Site 2", slug="ui-site-2") + self.prefix1 = Prefix.objects.create( + prefix="192.168.0.0/24", + status=PrefixStatusChoices.STATUS_ACTIVE, + ) + + self.site_ot = ObjectType.objects.get(app_label="dcim", model="site") + self.prefix_ot = ObjectType.objects.get(app_label="ipam", model="prefix") + + self.cot = CustomObjectType.objects.create( + name="UIPolyTest", slug="ui-poly-test", + verbose_name_plural="UI Poly Tests", + ) + CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="name", label="Name", type="text", + primary=True, required=True, + ) + self.gfk_field = CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="poly_obj", label="Poly Obj", type="object", + is_polymorphic=True, + ) + self.gfk_field.related_object_types.set([self.site_ot, self.prefix_ot]) + + self.m2m_field = CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="poly_multi", label="Poly Multi", type="multiobject", + is_polymorphic=True, + ) + self.m2m_field.related_object_types.set([self.site_ot, self.prefix_ot]) + + self.model = self.cot.get_model() + + # Grant the user all relevant permissions + for action in ("view", "add", "change", "delete"): + _grant_perm(self.user, action, self.model, f"ui-{action}") + _grant_perm(self.user, action, CustomObjectTypeField, f"ui-field-{action}") + # restrict_form_fields() restricts DynamicModelChoiceField querysets to objects + # the user can view; grant view on Site and Prefix so form validation passes. + _grant_perm(self.user, "view", Site, "ui-site-view") + _grant_perm(self.user, "view", Prefix, "ui-prefix-view") + + def _add_url(self): + return reverse( + "plugins:netbox_custom_objects:customobject_add", + kwargs={"custom_object_type": self.cot.slug}, + ) + + def _edit_url(self, pk): + return reverse( + "plugins:netbox_custom_objects:customobject_edit", + kwargs={"custom_object_type": self.cot.slug, "pk": pk}, + ) + + def _bulk_edit_url(self): + return reverse( + "plugins:netbox_custom_objects:customobject_bulk_edit", + kwargs={"custom_object_type": self.cot.slug}, + ) + + def _field_delete_url(self, field_pk): + return reverse( + "plugins:netbox_custom_objects:customobjecttypefield_delete", + kwargs={"pk": field_pk}, + ) + + # --- Edit form structure --- + + def test_edit_form_has_per_type_subfields_not_raw_gfk_columns(self): + """The edit form exposes per-type sub-fields, not the raw _content_type/_object_id.""" + obj = self.model.objects.create(name="form-test-obj") + response = self.client.get(self._edit_url(obj.pk)) + self.assertEqual(response.status_code, 200) + form = response.context["form"] + # Per-type sub-fields must be present + self.assertIn("poly_obj__dcim__site", form.fields) + self.assertIn("poly_obj__ipam__prefix", form.fields) + # Raw GFK columns must be excluded + self.assertNotIn("poly_obj_content_type", form.fields) + self.assertNotIn("poly_obj_object_id", form.fields) + + def test_edit_form_has_per_type_subfields_for_m2m(self): + """The edit form exposes per-type sub-fields for polymorphic MultiObject.""" + obj = self.model.objects.create(name="m2m-form-obj") + response = self.client.get(self._edit_url(obj.pk)) + self.assertEqual(response.status_code, 200) + form = response.context["form"] + self.assertIn("poly_multi__dcim__site", form.fields) + self.assertIn("poly_multi__ipam__prefix", form.fields) + + def test_edit_form_subfield_labels_are_human_readable(self): + """Sub-field labels use the field's label and the human-readable type name.""" + obj = self.model.objects.create(name="label-test-obj") + response = self.client.get(self._edit_url(obj.pk)) + form = response.context["form"] + # Should be "Poly Obj (DCIM > Site)" not "poly_obj (dcim.site)" + label = form.fields["poly_obj__dcim__site"].label + self.assertIn("Poly Obj", label) + self.assertNotIn("dcim.site", label) + + def test_edit_form_preselects_existing_gfk_value(self): + """For an existing object, the correct sub-field is pre-populated.""" + obj = self.model.objects.create(name="prefill-test") + obj.poly_obj = self.site1 + obj.save() + + response = self.client.get(self._edit_url(obj.pk)) + form = response.context["form"] + # Initial values are stored in form.initial (not on the field itself, since + # DynamicModelChoiceField sets initial via the form constructor's initial= kwarg) + site_initial = form.initial.get("poly_obj__dcim__site") + self.assertEqual(site_initial, self.site1.pk) + # Prefix sub-field initial should be empty + prefix_initial = form.initial.get("poly_obj__ipam__prefix") + self.assertFalse(prefix_initial) + + # --- Edit form submission --- + + def test_submit_edit_form_sets_gfk_to_site(self): + """POST the edit form with a Site sub-field saves the GFK to that Site.""" + obj = self.model.objects.create(name="submit-gfk-obj") + data = { + "name": "submit-gfk-obj", + "poly_obj__dcim__site": self.site1.pk, + "csrfmiddlewaretoken": "fake", + } + response = self.client.post(self._edit_url(obj.pk), data, follow=True) + self.assertNotIn(response.status_code, [400, 403, 500]) + obj.refresh_from_db() + self.assertEqual(obj.poly_obj, self.site1) + + def test_submit_edit_form_rejects_multiple_gfk_subfields(self): + """POST with more than one GFK sub-field filled returns a form error.""" + obj = self.model.objects.create(name="multi-gfk-obj") + data = { + "name": "multi-gfk-obj", + # Both sub-fields filled — should be rejected + "poly_obj__dcim__site": self.site1.pk, + "poly_obj__ipam__prefix": self.prefix1.pk, + "csrfmiddlewaretoken": "fake", + } + # Don't follow redirects: success redirects (302), validation error re-renders (200) + response = self.client.post(self._edit_url(obj.pk), data) + self.assertEqual(response.status_code, 200, "Expected form to be re-rendered with errors") + form = response.context["form"] + self.assertTrue(form.errors, "Expected form errors but found none") + # Both conflicting sub-fields should carry an error + self.assertIn("poly_obj__dcim__site", form.errors) + self.assertIn("poly_obj__ipam__prefix", form.errors) + # Object must not have been modified + obj.refresh_from_db() + self.assertIsNone(obj.poly_obj) + + def test_submit_edit_form_clears_gfk_when_no_subfield_selected(self): + """POST with no sub-field selected clears an existing GFK value.""" + obj = self.model.objects.create(name="clear-gfk-obj") + obj.poly_obj = self.site1 + obj.save() + + data = {"name": "clear-gfk-obj", "csrfmiddlewaretoken": "fake"} + response = self.client.post(self._edit_url(obj.pk), data, follow=True) + self.assertNotIn(response.status_code, [400, 403, 500]) + obj.refresh_from_db() + self.assertIsNone(obj.poly_obj) + + def test_submit_edit_form_sets_polymorphic_m2m(self): + """POST the edit form with M2M sub-fields saves values across types.""" + obj = self.model.objects.create(name="submit-m2m-obj") + data = { + "name": "submit-m2m-obj", + "poly_multi__dcim__site": [self.site1.pk, self.site2.pk], + "poly_multi__ipam__prefix": [self.prefix1.pk], + "csrfmiddlewaretoken": "fake", + } + response = self.client.post(self._edit_url(obj.pk), data, follow=True) + self.assertNotIn(response.status_code, [400, 403, 500]) + members = obj.poly_multi.all() + self.assertIn(self.site1, members) + self.assertIn(self.site2, members) + self.assertIn(self.prefix1, members) + + def test_submit_add_form_creates_object_with_polymorphic_m2m(self): + """POST the add form creates a new custom object with polymorphic M2M values.""" + data = { + "name": "add-m2m-obj", + "poly_multi__dcim__site": [self.site1.pk], + "csrfmiddlewaretoken": "fake", + } + response = self.client.post(self._add_url(), data, follow=True) + self.assertNotIn(response.status_code, [400, 403, 500]) + obj = self.model.objects.get(name="add-m2m-obj") + self.assertIn(self.site1, obj.poly_multi.all()) + + # --- Bulk edit form --- + + def test_bulk_edit_form_has_polymorphic_subfields(self): + """The bulk edit form also exposes per-type sub-fields for polymorphic fields.""" + obj1 = self.model.objects.create(name="bulk-1") + obj2 = self.model.objects.create(name="bulk-2") + # POST without _apply renders the form + response = self.client.post( + self._bulk_edit_url(), + data={"pk": [obj1.pk, obj2.pk]}, + ) + self.assertEqual(response.status_code, 200) + form = response.context["form"] + self.assertIn("poly_obj__dcim__site", form.fields) + self.assertIn("poly_multi__dcim__site", form.fields) + self.assertNotIn("poly_obj_content_type", form.fields) + + def test_bulk_edit_applies_gfk_to_all_selected_objects(self): + """Bulk edit sets the GFK field on all selected objects.""" + obj1 = self.model.objects.create(name="bulk-gfk-1") + obj2 = self.model.objects.create(name="bulk-gfk-2") + data = { + "pk": [obj1.pk, obj2.pk], + "_apply": "1", + "poly_obj__dcim__site": self.site1.pk, + "csrfmiddlewaretoken": "fake", + } + response = self.client.post(self._bulk_edit_url(), data, follow=True) + self.assertNotIn(response.status_code, [400, 403, 500]) + obj1.refresh_from_db() + obj2.refresh_from_db() + self.assertEqual(obj1.poly_obj, self.site1) + self.assertEqual(obj2.poly_obj, self.site1) + + def test_bulk_edit_applies_m2m_to_all_selected_objects(self): + """Bulk edit sets poly M2M on all selected objects.""" + obj1 = self.model.objects.create(name="bulk-m2m-1") + obj2 = self.model.objects.create(name="bulk-m2m-2") + data = { + "pk": [obj1.pk, obj2.pk], + "_apply": "1", + "poly_multi__dcim__site": [self.site1.pk], + "csrfmiddlewaretoken": "fake", + } + response = self.client.post(self._bulk_edit_url(), data, follow=True) + self.assertNotIn(response.status_code, [400, 403, 500]) + self.assertIn(self.site1, obj1.poly_multi.all()) + self.assertIn(self.site1, obj2.poly_multi.all()) + + # --- Delete confirmation for polymorphic fields --- + + def test_delete_confirmation_page_for_polymorphic_m2m_field_returns_200(self): + """GET the delete confirmation page for a polymorphic M2M field does not raise FieldError.""" + obj = self.model.objects.create(name="del-m2m-obj") + obj.poly_multi.add(self.site1) + + response = self.client.get(self._field_delete_url(self.m2m_field.pk)) + self.assertEqual(response.status_code, 200) + + def test_delete_confirmation_page_for_polymorphic_gfk_field_returns_200(self): + """GET the delete confirmation page for a polymorphic GFK field does not raise FieldError.""" + obj = self.model.objects.create(name="del-gfk-obj") + obj.poly_obj = self.site1 + obj.save() + + response = self.client.get(self._field_delete_url(self.gfk_field.pk)) + self.assertEqual(response.status_code, 200) + + +# --------------------------------------------------------------------------- +# Through-model registration +# --------------------------------------------------------------------------- + +class PolymorphicThroughModelRegistrationTest( + TransactionCleanupMixin, CustomObjectsTestCase, TransactionTestCase +): + """Verify the through model is re-registered on model regeneration (simulates restart).""" + + def setUp(self): + super().setUp() + from utilities.testing import create_test_user + self.user = create_test_user("poly-reg-user") + self.client.force_login(self.user) + + self.cot = CustomObjectType.objects.create( + name="RegTest", slug="reg-test", + verbose_name_plural="Reg Tests", + ) + CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="name", type="text", primary=True, + ) + self.m2m_field = CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="links", type="multiobject", is_polymorphic=True, + ) + site_ot = ObjectType.objects.get(app_label="dcim", model="site") + self.m2m_field.related_object_types.set([site_ot]) + + def test_through_model_registered_after_get_model(self): + """After clearing the cache and calling get_model(), the through model is accessible.""" + from django.apps import apps as django_apps + from netbox_custom_objects.constants import APP_LABEL + + # Simulate restart by clearing cache + CustomObjectType.clear_model_cache() + + # Re-generate the model (as a request would) + self.cot.get_model() + + # The through model must be findable + through_name = self.m2m_field.through_model_name + try: + through = django_apps.get_model(APP_LABEL, through_name) + except LookupError: + self.fail( + f"Through model '{through_name}' not in app registry after get_model()." + ) + self.assertIsNotNone(through) + + +# --------------------------------------------------------------------------- +# Referenced-object deletion +# --------------------------------------------------------------------------- + +class ReferencedObjectDeletionTest( + TransactionCleanupMixin, CustomObjectsTestCase, TransactionTestCase +): + """ + Tests for the on_delete behaviour when a referenced object (e.g. a Site) or + the CustomObjectType itself is deleted. + + GFK fields use on_delete=SET_NULL on the content_type FK (so deleting the + ContentType row nulls the pointer), but there is no DB-level FK from + object_id to the concrete target table — generic relations cannot express + that. Deleting a Site therefore leaves a stale object_id; Django's GFK + accessor silently returns None in that case. + + Polymorphic M2M through-tables store (source_id, content_type_id, object_id). + The content_type FK has on_delete=CASCADE (so deleting the ContentType drops + the through-table rows), but again there is no FK from object_id to the + concrete target table. Deleting a Site leaves a stale through-table row; + PolymorphicManyToManyManager._get_objects() already skips such rows because + the batch-fetch query returns no matching object for the deleted PK. + """ + + def setUp(self): + super().setUp() + + self.site = Site.objects.create(name="Del Site", slug="del-site") + self.site_ot = ObjectType.objects.get(app_label="dcim", model="site") + self.prefix_ot = ObjectType.objects.get(app_label="ipam", model="prefix") + + self.cot = CustomObjectType.objects.create( + name="DelTest", slug="del-test", + verbose_name_plural="Del Tests", + ) + CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="name", type="text", primary=True, required=True, + ) + self.gfk_field = CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="poly_obj", label="Poly Obj", type="object", + is_polymorphic=True, + ) + self.gfk_field.related_object_types.set([self.site_ot, self.prefix_ot]) + + self.m2m_field = CustomObjectTypeField.objects.create( + custom_object_type=self.cot, + name="poly_multi", label="Poly Multi", type="multiobject", + is_polymorphic=True, + ) + self.m2m_field.related_object_types.set([self.site_ot, self.prefix_ot]) + + self.model = self.cot.get_model() + + def test_deleting_referenced_site_nulls_gfk_accessor(self): + """ + Deleting a Site that is referenced by a polymorphic GFK field does not + raise an exception and causes the accessor to return None. + + There is no DB FK from object_id → dcim_site, so the row is not touched + at the DB level; Django's GenericForeignKey.__get__ returns None when + the target object no longer exists. + """ + obj = self.model.objects.create(name="stale-gfk") + obj.poly_obj = self.site + obj.save() + + self.site.delete() + + obj.refresh_from_db() + # The content_type column still points to the Site ContentType (site still + # exists in the app registry), but the object is gone — accessor returns None. + self.assertIsNone(obj.poly_obj) + + def test_deleting_referenced_site_leaves_stale_m2m_row_excluded_from_all(self): + """ + Deleting a Site that is in a polymorphic M2M through table leaves a stale + row in the through table (no DB FK on object_id), but all() gracefully + excludes it because the batch-fetch query returns no matching object. + """ + obj = self.model.objects.create(name="stale-m2m") + obj.poly_multi.add(self.site) + + site_pk = self.site.pk + self.site.delete() + + # Verify the stale row persists: the through table has no DB-level FK from + # object_id to dcim_site, so deleting a Site cannot cascade into the through + # table. Fetch the through model directly from the manager to avoid a + # global app-registry lookup that could resolve to a stale model from a + # prior test run when using --keepdb. + manager = obj.poly_multi # PolymorphicManyToManyManager + through = manager._get_through_model() + self.assertTrue( + through.objects.filter(source_id=obj.pk, object_id=site_pk).exists(), + "Stale through-table row should remain after target deletion (no DB FK on object_id).", + ) + + # all() skips the stale row — the result list must be empty. + self.assertEqual(list(obj.poly_multi.all()), []) + + def test_deleting_custom_object_type_drops_db_table_and_deregisters_model(self): + """ + Deleting a CustomObjectType drops its DB table, drops polymorphic through + tables, and removes the model from Django's app registry. + """ + from django.apps import apps as django_apps + from django.db import connection + from netbox_custom_objects.constants import APP_LABEL + + main_table = self.cot.get_database_table_name() + through_table = self.m2m_field.through_table_name + through_model_name = self.m2m_field.through_model_name + model_name = self.model.__name__.lower() + + # Create a row so the delete path exercises cascade logic too. + obj = self.model.objects.create(name="to-be-cascaded") + obj.poly_multi.add(self.site) + + self.cot.delete() + + with connection.cursor() as cursor: + existing_tables = connection.introspection.table_names(cursor) + + self.assertNotIn(main_table, existing_tables, "Main DB table should be dropped.") + self.assertNotIn(through_table, existing_tables, "Through table should be dropped.") + + # Model and through model must be de-registered from the app registry. + self.assertNotIn(model_name, django_apps.all_models.get(APP_LABEL, {})) + self.assertNotIn( + through_model_name.lower(), django_apps.all_models.get(APP_LABEL, {}) + ) + + +# --------------------------------------------------------------------------- +# Cycle-detection gap: multi-hop polymorphic cycles +# --------------------------------------------------------------------------- + +class PolymorphicCycleDetectionGapTest( + TransactionCleanupMixin, CustomObjectsTestCase, TransactionTestCase +): + """ + Pins the KNOWN LIMITATION in CustomObjectTypeField._has_circular_reference: + a cycle that passes entirely through polymorphic legs is NOT detected. + + The DFS in _has_circular_reference only follows non-polymorphic object/ + multiobject fields (those with related_object_type set). Polymorphic fields + store their allowed types on the related_object_types M2M and are therefore + invisible to the traversal. + + Scenario: COT-A has a polymorphic field that allows COT-B objects, and COT-B + has a polymorphic field that allows COT-A objects. This is a cycle, but the + check_polymorphic_recursion signal cannot see it because _has_circular_reference + returns False for every hop that uses a polymorphic field. + + TODO: when _has_circular_reference is extended to traverse polymorphic legs, + the assertions below should be changed to assertRaises(ValidationError) and + this docstring updated accordingly. + """ + + def setUp(self): + super().setUp() + + # COT A + self.cot_a = CustomObjectType.objects.create( + name="CycleA", slug="cycle-a", verbose_name_plural="Cycle As", + ) + CustomObjectTypeField.objects.create( + custom_object_type=self.cot_a, + name="name", type="text", primary=True, required=True, + ) + + # COT B + self.cot_b = CustomObjectType.objects.create( + name="CycleB", slug="cycle-b", verbose_name_plural="Cycle Bs", + ) + CustomObjectTypeField.objects.create( + custom_object_type=self.cot_b, + name="name", type="text", primary=True, required=True, + ) + + # Ensure both models are generated so their ObjectTypes exist in the registry. + self.cot_a.get_model() + self.cot_b.get_model() + + # Resolve the ObjectType (ContentType) for each generated model. + self.ot_a = ObjectType.objects.get_for_model(self.cot_a.get_model()) + self.ot_b = ObjectType.objects.get_for_model(self.cot_b.get_model()) + + def test_multihop_polymorphic_cycle_is_not_detected(self): + """ + KNOWN LIMITATION: A →(poly) B →(poly) A is silently accepted. + + The first set() (A allows B) succeeds because neither COT has any + non-polymorphic back-edges yet. The second set() (B allows A) also + succeeds because _has_circular_reference traverses only non-polymorphic + fields, so the A→B poly edge is invisible. + + Neither call should raise ValidationError under the current implementation. + If this test starts failing with ValidationError it means the gap has + been fixed and the test should be updated to assert the opposite. + """ + # Step 1: create a polymorphic object field on A that allows B objects. + field_a = CustomObjectTypeField.objects.create( + custom_object_type=self.cot_a, + name="link_to_b", type="object", is_polymorphic=True, + ) + try: + field_a.related_object_types.set([self.ot_b]) + except Exception as exc: + self.fail( + f"Adding COT-B as allowed type for COT-A's poly field raised " + f"{type(exc).__name__}: {exc}. Expected no error (no cycle yet)." + ) + + # Step 2: create a polymorphic object field on B that allows A objects. + # This creates a cycle A →(poly) B →(poly) A, but _has_circular_reference + # does not traverse poly legs so it goes undetected. + field_b = CustomObjectTypeField.objects.create( + custom_object_type=self.cot_b, + name="link_to_a", type="object", is_polymorphic=True, + ) + try: + field_b.related_object_types.set([self.ot_a]) + except Exception as exc: + self.fail( + f"Adding COT-A as allowed type for COT-B's poly field raised " + f"{type(exc).__name__}: {exc}. " + f"This is a known limitation — multi-hop polymorphic cycles are not " + f"detected by _has_circular_reference. If the TODO has been resolved " + f"and a ValidationError is now expected here, update this test." + ) + + # Confirm the M2M rows were actually written. + self.assertIn(self.ot_b, field_a.related_object_types.all()) + self.assertIn(self.ot_a, field_b.related_object_types.all()) diff --git a/netbox_custom_objects/views.py b/netbox_custom_objects/views.py index 352238cd..74301a9b 100644 --- a/netbox_custom_objects/views.py +++ b/netbox_custom_objects/views.py @@ -19,7 +19,9 @@ from netbox.views import generic from netbox.views.generic.mixins import TableMixin from utilities.forms import ConfirmationForm +from utilities.forms.fields import DynamicModelChoiceField, DynamicModelMultipleChoiceField from utilities.htmx import htmx_partial +from utilities.object_types import object_type_name from utilities.views import ConditionalLoginRequiredMixin, ViewTab, get_viewname, register_model_view from netbox_custom_objects.filtersets import get_filterset_class @@ -34,6 +36,59 @@ logger = logging.getLogger("netbox_custom_objects.views") +# --------------------------------------------------------------------------- +# Sub-field naming helpers for polymorphic form fields +# +# Polymorphic GFK and M2M fields are split into one form sub-field per allowed +# content type, named "{field_name}__{app_label}__{model}". These helpers +# centralise that convention so it isn't repeated across build and parse sites. +# --------------------------------------------------------------------------- + +def _poly_sub_name(field_name: str, app_label: str, model: str) -> str: + """Return the form sub-field name for one content type of a polymorphic field.""" + return f"{field_name}__{app_label}__{model}" + + +def _parse_poly_sub_name(field_name: str, sub_name: str) -> tuple[str, str]: + """Parse a polymorphic sub-field name and return (app_label, model).""" + suffix = sub_name[len(field_name) + 2:] # strip "{field_name}__" + app_label, model = suffix.split("__", 1) + return app_label, model + + +def _build_poly_subfields(field, set_initial: bool = False): + """ + Build per-type form sub-fields for a polymorphic Object or MultiObject field. + + Yields ``(sub_name, sub_field)`` pairs — one per allowed object type whose + model class can be resolved. Types whose ``model_class()`` returns ``None`` + (e.g. orphaned ContentType rows) are silently skipped. + + Args: + field: A ``CustomObjectTypeField`` instance with ``is_polymorphic=True``. + set_initial: When ``True``, sets ``sub_field.initial = None`` on each + generated field (required for bulk-edit forms). + """ + is_multi = field.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT + field_class = DynamicModelMultipleChoiceField if is_multi else DynamicModelChoiceField + field_label = field.label or field.name.replace("_", " ").title() + + for ot in field.related_object_types.all(): + sub_model = ot.model_class() + if sub_model is None: + continue + sub_name = _poly_sub_name(field.name, ot.app_label, ot.model) + sub_field = field_class( + queryset=sub_model.objects.all(), + required=False, + label=f"{field_label} ({object_type_name(ot)})", + selector=ot.app_label != APP_LABEL, + ) + if set_initial: + sub_field.initial = None + yield sub_name, sub_field + + class CustomJournalEntryForm(JournalEntryForm): """ Custom journal entry form that handles return URLs for custom objects. @@ -276,10 +331,20 @@ def get(self, request, *args, **kwargs): form = ConfirmationForm(initial=request.GET) model = obj.custom_object_type.get_model_with_serializer() - kwargs = { - f"{obj.name}__isnull": False, - } - num_dependent_objects = model.objects.filter(**kwargs).count() + if obj.is_polymorphic and obj.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + # Polymorphic M2M: count via the through table (field is a descriptor, not a real FK) + from django.apps import apps as django_apps + try: + through = django_apps.get_model(APP_LABEL, obj.through_model_name) + num_dependent_objects = through.objects.values("source_id").distinct().count() + except LookupError: + num_dependent_objects = 0 + elif obj.is_polymorphic and obj.type == CustomFieldTypeChoices.TYPE_OBJECT: + # Polymorphic Object (GFK): query via the concrete content_type column + ct_field = f"{obj.name}_content_type__isnull" + num_dependent_objects = model.objects.filter(**{ct_field: False}).count() + else: + num_dependent_objects = model.objects.filter(**{f"{obj.name}__isnull": False}).count() # If this is an HTMX request, return only the rendered deletion form as modal content if htmx_partial(request): @@ -313,10 +378,24 @@ def get(self, request, *args, **kwargs): def _get_dependent_objects(self, obj): dependent_objects = super()._get_dependent_objects(obj) model = obj.custom_object_type.get_model_with_serializer() - kwargs = { - f"{obj.name}__isnull": False, - } - dependent_objects[model] = list(model.objects.filter(**kwargs)) + + if obj.is_polymorphic and obj.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: + # Polymorphic M2M: the field is a descriptor, not a real column — query + # via the through table and return the source objects. + from django.apps import apps as django_apps + try: + through = django_apps.get_model(APP_LABEL, obj.through_model_name) + source_ids = through.objects.values_list("source_id", flat=True).distinct() + dependent_objects[model] = list(model.objects.filter(pk__in=source_ids)) + except LookupError: + dependent_objects[model] = [] + elif obj.is_polymorphic and obj.type == CustomFieldTypeChoices.TYPE_OBJECT: + # Polymorphic GFK: filter on the concrete content_type column. + ct_field = f"{obj.name}_content_type__isnull" + dependent_objects[model] = list(model.objects.filter(**{ct_field: False})) + else: + dependent_objects[model] = list(model.objects.filter(**{f"{obj.name}__isnull": False})) + return dependent_objects @@ -458,12 +537,22 @@ def get_object(self, **kwargs): return get_object_or_404(model.objects.all(), **self.kwargs) def get_form(self, model): + # Collect raw GFK column names to exclude from the auto-generated form fields. + # For each polymorphic Object field "foo", Django adds "foo_content_type" and + # "foo_object_id" as real model columns; we replace those with per-type selects. + poly_obj_raw_exclude = [] + for f in self.object.custom_object_type.fields.filter( + type=CustomFieldTypeChoices.TYPE_OBJECT, is_polymorphic=True + ): + poly_obj_raw_exclude += [f"{f.name}_content_type", f"{f.name}_object_id"] + meta = type( "Meta", (), { "model": model, "fields": "__all__", + "exclude": poly_obj_raw_exclude, }, ) @@ -473,13 +562,40 @@ def get_form(self, model): "_errors": None, "custom_object_type_fields": {}, "custom_object_type_field_groups": {}, + # Maps polymorphic M2M field name → list of sub-field names (one per allowed type) + "custom_object_type_poly_m2m_fields": {}, + # Maps polymorphic Object field name → list of sub-field names (one per allowed type) + "custom_object_type_poly_obj_fields": {}, } # Process custom object type fields (with grouping) - for field in self.object.custom_object_type.fields.all().order_by( - "group_name", "weight", "name" - ): + for field in self.object.custom_object_type.fields.prefetch_related( + 'related_object_types' + ).order_by("group_name", "weight", "name"): field_type = field_types.FIELD_TYPE_CLASS[field.type]() + group_name = field.group_name or None + + # Polymorphic object/multiobject: one form sub-field per allowed type + if field.is_polymorphic and field.type in ( + CustomFieldTypeChoices.TYPE_OBJECT, + CustomFieldTypeChoices.TYPE_MULTIOBJECT, + ): + sub_names = [] + for sub_name, sub_field in _build_poly_subfields(field): + attrs[sub_name] = sub_field + sub_names.append(sub_name) + if group_name not in attrs["custom_object_type_field_groups"]: + attrs["custom_object_type_field_groups"][group_name] = [] + attrs["custom_object_type_field_groups"][group_name].append(sub_name) + + dest_key = ( + "custom_object_type_poly_obj_fields" + if field.type == CustomFieldTypeChoices.TYPE_OBJECT + else "custom_object_type_poly_m2m_fields" + ) + attrs[dest_key][field.name] = sub_names + continue + try: field_name = field.name attrs[field_name] = field_type.get_annotated_form_field(field) @@ -488,7 +604,6 @@ def get_form(self, model): attrs["custom_object_type_fields"][field_name] = field # Group fields by group_name (similar to NetBox custom fields) - group_name = field.group_name or None # Use None for ungrouped fields if group_name not in attrs["custom_object_type_field_groups"]: attrs["custom_object_type_field_groups"][group_name] = [] attrs["custom_object_type_field_groups"][group_name].append(field_name) @@ -496,11 +611,6 @@ def get_form(self, model): except NotImplementedError: logger.debug("get_form: {} field is not supported".format(field.name)) - # Note: Regular model fields (non-custom fields) are automatically included - # by the "fields": "__all__" setting in the Meta class, so we don't need - # to manually add them to the form attributes or grouping structure. - # The template will be able to access them directly through the form. - form_class = type( f"{model._meta.object_name}Form", (forms.NetBoxModelForm,), @@ -511,22 +621,22 @@ def get_form(self, model): def custom_init(self, *args, **kwargs): # Set the grouping info as instance attributes from the outer scope self.custom_object_type_fields = attrs["custom_object_type_fields"] - self.custom_object_type_field_groups = attrs[ - "custom_object_type_field_groups" - ] + self.custom_object_type_field_groups = attrs["custom_object_type_field_groups"] + self.custom_object_type_poly_m2m_fields = attrs["custom_object_type_poly_m2m_fields"] + self.custom_object_type_poly_obj_fields = attrs["custom_object_type_poly_obj_fields"] - # Handle default values for MultiObject fields BEFORE calling parent __init__ - # This ensures the initial values are set before Django processes the form instance = kwargs.get('instance', None) + + if 'initial' not in kwargs: + kwargs['initial'] = {} + + # Set initial values for non-polymorphic MultiObject defaults on new instances if not instance or not instance.pk: - # Only set defaults for new instances (not when editing existing ones) for field_name, field_obj in self.custom_object_type_fields.items(): if field_obj.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: if field_obj.default and isinstance(field_obj.default, list): - # Get the related model content_type = field_obj.related_object_type if content_type.app_label == APP_LABEL: - # Custom object type from netbox_custom_objects.models import CustomObjectType custom_object_type_id = extract_cot_id_from_model_name(content_type.model) if custom_object_type_id is None: @@ -535,24 +645,67 @@ def custom_init(self, *args, **kwargs): f"got {content_type.model!r}" ) custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) - model = custom_object_type.get_model(skip_object_fields=True) + related_model = custom_object_type.get_model(skip_object_fields=True) else: - # Regular NetBox model - model = content_type.model_class() - + related_model = content_type.model_class() try: - # Query the database to get the actual objects - initial_objects = model.objects.filter(pk__in=field_obj.default) - # Convert to list of IDs for ModelMultipleChoiceField - initial_ids = list(initial_objects.values_list('pk', flat=True)) - - # Set the initial value in the form's initial data - if 'initial' not in kwargs: - kwargs['initial'] = {} + initial_ids = list( + related_model.objects.filter(pk__in=field_obj.default) + .values_list('pk', flat=True) + ) kwargs['initial'][field_name] = initial_ids except Exception: - # If there's an error, don't set initial values + logger.debug( + "Failed to load default initial values for field %r", + field_name, exc_info=True, + ) + + # Set initial values for polymorphic sub-fields from the existing instance + if instance and instance.pk: + from django.contrib.contenttypes.models import ContentType as CT + from django.apps import apps as django_apps + + # M2M: read through-table rows and group by content type + for field_name, sub_names in self.custom_object_type_poly_m2m_fields.items(): + try: + field_obj = instance.custom_object_type.fields.get(name=field_name) + through = django_apps.get_model(APP_LABEL, field_obj.through_model_name) + rows = through.objects.filter(source_id=instance.pk).values_list( + "content_type_id", "object_id" + ) + by_ct = {} + for ct_id, obj_id in rows: + by_ct.setdefault(ct_id, []).append(obj_id) + + for sub_name in sub_names: + app_label, model_name = _parse_poly_sub_name(field_name, sub_name) + try: + ct = CT.objects.get(app_label=app_label, model=model_name) + kwargs['initial'][sub_name] = by_ct.get(ct.pk, []) + except CT.DoesNotExist: pass + except Exception: + logger.debug( + "Failed to load polymorphic M2M initial values for field %r", + field_name, exc_info=True, + ) + + # GFK: pre-populate the matching type's sub-field + for field_name, sub_names in self.custom_object_type_poly_obj_fields.items(): + try: + gfk_value = getattr(instance, field_name, None) + if gfk_value is not None: + ct = CT.objects.get_for_model(gfk_value) + for sub_name in sub_names: + app_label, model_name = _parse_poly_sub_name(field_name, sub_name) + if ct.app_label == app_label and ct.model == model_name: + kwargs['initial'][sub_name] = gfk_value.pk + break + except Exception: + logger.debug( + "Failed to load polymorphic GFK initial value for field %r", + field_name, exc_info=True, + ) # Now call the parent __init__ with the modified kwargs forms.NetBoxModelForm.__init__(self, *args, **kwargs) @@ -565,28 +718,61 @@ def custom_save(self, commit=True): if commit: instance.save() - # Handle M2M fields manually to ensure proper clearing and setting + # Handle non-polymorphic M2M fields for field_name, field_obj in self.custom_object_type_fields.items(): if field_obj.type == CustomFieldTypeChoices.TYPE_MULTIOBJECT: - # Get the current value from the form current_value = self.cleaned_data.get(field_name, []) - - # Get the field from the instance instance_field = getattr(instance, field_name) - - # Clear existing relationships and set new ones if hasattr(instance_field, 'clear') and hasattr(instance_field, 'set'): instance_field.clear() - if current_value: instance_field.set(current_value) + # Handle polymorphic single-object sub-fields: use the first non-empty selection + for field_name, sub_names in self.custom_object_type_poly_obj_fields.items(): + chosen = None + for sub_name in sub_names: + val = self.cleaned_data.get(sub_name) + if val is not None: + chosen = val + break + setattr(instance, field_name, chosen) + if self.custom_object_type_poly_obj_fields: + instance.save() + + # Handle polymorphic M2M sub-fields: aggregate per-type selections + for field_name, sub_names in self.custom_object_type_poly_m2m_fields.items(): + combined = [] + for sub_name in sub_names: + combined.extend(self.cleaned_data.get(sub_name, [])) + instance_field = getattr(instance, field_name) + instance_field.set(combined) + # Save M2M relationships self.save_m2m() return instance + def custom_clean(self): + # Call parent for side effects (custom field processing etc.). + # CheckLastUpdatedMixin.clean() does not propagate its return value, + # so the chain returns None; read self.cleaned_data directly instead. + forms.NetBoxModelForm.clean(self) + # Enforce that at most one sub-field is filled for each polymorphic + # single-object field. Multiple non-None values are ambiguous and + # would otherwise be silently resolved by "first non-empty wins". + for field_name, sub_names in self.custom_object_type_poly_obj_fields.items(): + filled = [sn for sn in sub_names if self.cleaned_data.get(sn) is not None] + if len(filled) > 1: + for sub_name in filled: + self.add_error( + sub_name, + _("Only one type may be selected for this field — clear all but one."), + ) + return self.cleaned_data + form_class.__init__ = custom_init + form_class.clean = custom_clean form_class.save = custom_save return form_class @@ -663,28 +849,53 @@ def get_queryset(self, request): return model.objects.all() def get_form(self, queryset): + poly_obj_raw_exclude = [] + for f in self.custom_object_type.fields.filter( + type=CustomFieldTypeChoices.TYPE_OBJECT, is_polymorphic=True + ): + poly_obj_raw_exclude += [f"{f.name}_content_type", f"{f.name}_object_id"] + meta = type( "Meta", (), { "model": queryset.model, "fields": "__all__", + "exclude": poly_obj_raw_exclude, }, ) attrs = { "Meta": meta, "__module__": "database.forms", + "_poly_obj_field_map": {}, # field_name → [sub_names] + "_poly_m2m_field_map": {}, # field_name → [sub_names] } - for field in self.custom_object_type.fields.all(): + for field in self.custom_object_type.fields.prefetch_related('related_object_types').all(): field_type = field_types.FIELD_TYPE_CLASS[field.type]() + + # Polymorphic object/multiobject: one form sub-field per allowed type + if field.is_polymorphic and field.type in ( + CustomFieldTypeChoices.TYPE_OBJECT, + CustomFieldTypeChoices.TYPE_MULTIOBJECT, + ): + sub_names = [] + for sub_name, sub_field in _build_poly_subfields(field, set_initial=True): + attrs[sub_name] = sub_field + sub_names.append(sub_name) + + dest_key = ( + "_poly_obj_field_map" + if field.type == CustomFieldTypeChoices.TYPE_OBJECT + else "_poly_m2m_field_map" + ) + attrs[dest_key][field.name] = sub_names + continue + try: form_field = field_type.get_annotated_form_field(field) # In bulk edit forms, all fields should be optional and start blank. - # Setting required=False prevents validation errors; setting initial=None - # ensures has_changed() only returns True when the user explicitly sets a - # value, preventing spurious updates when the field default is non-None. form_field.required = False form_field.widget.is_required = False form_field.initial = None @@ -699,12 +910,39 @@ def get_form(self, queryset): (NetBoxModelBulkEditForm,), attrs, ) - - # Set the model attribute that NetBox form mixins expect form.model = queryset.model - return form + def post_save_operations(self, form, obj): + super().post_save_operations(form, obj) + + # Apply polymorphic single-object sub-fields (first non-empty selection wins) + needs_save = False + for field_name, sub_names in form._poly_obj_field_map.items(): + for sub_name in sub_names: + val = form.cleaned_data.get(sub_name) + if val is not None: + setattr(obj, field_name, val) + needs_save = True + break + if needs_save: + obj.save() + + # Apply polymorphic M2M sub-fields (union of all selected types). + # set() replaces existing values, matching NetBox's standard bulk-edit + # behavior for direct M2M fields (see BulkEditView lines 718-723). + # Fields left blank are skipped so existing data is preserved. + for field_name, sub_names in form._poly_m2m_field_map.items(): + combined = [] + has_any = False + for sub_name in sub_names: + vals = form.cleaned_data.get(sub_name) or [] + if vals: + has_any = True + combined.extend(vals) + if has_any: + getattr(obj, field_name).set(combined) + def get_extra_context(self, request): return { 'branch_warning': is_in_branch(),