diff --git a/CHANGES.rst b/CHANGES.rst index 3a4993fa3..8fb280b6b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,18 @@ Changelog ========= +Version 1.2.3 [unreleased] +-------------------------- + +Bugfixes +~~~~~~~~ + +- Fixed VPN peer cache desync when VPN template is removed from a device: + ``VpnClient`` objects are now deleted via per-instance ``delete()`` so that + ``post_delete`` signals fire, ensuring peer cache invalidation, certificate + revocation, and IP address release `#1221 + `_ + Version 1.2.2 [2026-03-06] -------------------------- diff --git a/openwisp_controller/config/api/serializers.py b/openwisp_controller/config/api/serializers.py index 7e67d8f5f..adf2620a1 100644 --- a/openwisp_controller/config/api/serializers.py +++ b/openwisp_controller/config/api/serializers.py @@ -199,9 +199,15 @@ def _update_config(self, device, config_data): old_templates = list(config.templates.values_list("id", flat=True)) if config_templates != old_templates: with transaction.atomic(): - vpn_list = config.templates.filter(type="vpn").values_list("vpn") - if vpn_list: - config.vpnclient_set.exclude(vpn__in=vpn_list).delete() + new_vpn_ids = Template.objects.filter( + pk__in=config_templates, type="vpn" + ).values_list("vpn", flat=True) + for vpnclient in ( + config.vpnclient_set.select_related("vpn", "cert", "ip") + .exclude(vpn__in=new_vpn_ids) + .iterator() + ): + vpnclient.delete() config.templates.set(config_templates, clear=True) config.save() except ValidationError as error: diff --git a/openwisp_controller/config/base/config.py b/openwisp_controller/config/base/config.py index 139877344..2ecb72221 100644 --- a/openwisp_controller/config/base/config.py +++ b/openwisp_controller/config/base/config.py @@ -347,7 +347,11 @@ def manage_vpn_clients(cls, action, instance, pk_set, **kwargs): if instance.is_deactivating_or_deactivated(): # If the device is deactivated or in the process of deactivating, then # delete all vpn clients and return. - instance.vpnclient_set.all().delete() + with transaction.atomic(): + for vpnclient in instance.vpnclient_set.select_related( + "vpn", "cert", "ip" + ).iterator(): + vpnclient.delete() return vpn_client_model = cls.vpn.through @@ -379,9 +383,15 @@ def manage_vpn_clients(cls, action, instance, pk_set, **kwargs): # signal is triggered again—after all templates, including the required # ones, have been fully added. At that point, we can identify and # delete VpnClient objects not linked to the final template set. - instance.vpnclient_set.exclude( - template_id__in=instance.templates.values_list("id", flat=True) - ).delete() + with transaction.atomic(): + for vpnclient in ( + instance.vpnclient_set.select_related("vpn", "cert", "ip") + .exclude( + template_id__in=instance.templates.values_list("id", flat=True) + ) + .iterator() + ): + vpnclient.delete() if action == "post_add": for template in templates.filter(type="vpn"): diff --git a/openwisp_controller/config/tests/test_vpn.py b/openwisp_controller/config/tests/test_vpn.py index 0b9a52ad8..af1f0f0ae 100644 --- a/openwisp_controller/config/tests/test_vpn.py +++ b/openwisp_controller/config/tests/test_vpn.py @@ -287,6 +287,72 @@ def _assert_vpn_client_cert(cert, vpn_client, cert_ct, vpn_client_ct): vpnclient.save() _assert_vpn_client_cert(cert, vpnclient, 1, 0) + def test_vpn_client_post_delete_on_template_removal(self): + """Regression test for #1221: VpnClient.post_delete must fire + when a VPN template is removed so that peer cache is invalidated + and certificates are properly revoked.""" + org = self._get_org() + vpn = self._create_vpn() + t = self._create_template(name="vpn-test", type="vpn", vpn=vpn, auto_cert=True) + c = self._create_config(organization=org) + c.templates.add(t) + vpnclient = c.vpnclient_set.first() + self.assertIsNotNone(vpnclient) + cert_pk = vpnclient.cert.pk + with mock.patch.object(Vpn, "_invalidate_peer_cache") as mock_invalidate: + c.templates.remove(t) + mock_invalidate.assert_called_once() + self.assertFalse(VpnClient.objects.filter(pk=vpnclient.pk).exists()) + self.assertTrue(Cert.objects.get(pk=cert_pk).revoked) + + def test_vpn_client_post_delete_on_device_deactivation(self): + """Regression test for #1221: VpnClient.post_delete must fire + when a device is deactivated so that peer cache is invalidated + and certificates are properly revoked.""" + org = self._get_org() + vpn = self._create_vpn() + t = self._create_template(name="vpn-test", type="vpn", vpn=vpn, auto_cert=True) + d = self._create_device(organization=org) + c = self._create_config(device=d) + c.templates.add(t) + vpnclient = c.vpnclient_set.first() + self.assertIsNotNone(vpnclient) + cert_pk = vpnclient.cert.pk + with mock.patch.object(Vpn, "_invalidate_peer_cache") as mock_invalidate: + d.deactivate() + mock_invalidate.assert_called_once() + self.assertFalse(VpnClient.objects.filter(pk=vpnclient.pk).exists()) + self.assertTrue(Cert.objects.get(pk=cert_pk).revoked) + + def test_vpn_client_post_delete_multiple_clients(self): + """Regression test for #1221: when a device has multiple VPN templates, + deactivating it must delete every VpnClient, invalidate peer cache for + each VPN, and revoke all auto-created certificates.""" + org = self._get_org() + vpn1 = self._create_vpn(name="vpn1") + vpn2 = self._create_vpn(name="vpn2", ca=vpn1.ca) + t1 = self._create_template( + name="vpn-t1", type="vpn", vpn=vpn1, auto_cert=True + ) + t2 = self._create_template( + name="vpn-t2", type="vpn", vpn=vpn2, auto_cert=True + ) + d = self._create_device(organization=org) + c = self._create_config(device=d) + c.templates.add(t1, t2) + self.assertEqual(c.vpnclient_set.count(), 2) + vpnclient1 = c.vpnclient_set.get(vpn=vpn1) + vpnclient2 = c.vpnclient_set.get(vpn=vpn2) + cert_pk1 = vpnclient1.cert.pk + cert_pk2 = vpnclient2.cert.pk + with mock.patch.object(Vpn, "_invalidate_peer_cache") as mock_invalidate: + d.deactivate() + self.assertEqual(mock_invalidate.call_count, 2) + self.assertFalse(VpnClient.objects.filter(pk=vpnclient1.pk).exists()) + self.assertFalse(VpnClient.objects.filter(pk=vpnclient2.pk).exists()) + self.assertTrue(Cert.objects.get(pk=cert_pk1).revoked) + self.assertTrue(Cert.objects.get(pk=cert_pk2).revoked) + def test_vpn_client_get_common_name(self): vpn = self._create_vpn() d = self._create_device()