diff --git a/src/DispatchR/Configuration/ServiceRegistrator.cs b/src/DispatchR/Configuration/ServiceRegistrator.cs index b57d506..fe807b6 100644 --- a/src/DispatchR/Configuration/ServiceRegistrator.cs +++ b/src/DispatchR/Configuration/ServiceRegistrator.cs @@ -170,20 +170,33 @@ public static void RegisterHandlers(IServiceCollection services, List allT } } - public static void RegisterNotification(IServiceCollection services, List allTypes, - Type syncNotificationHandlerType) - { - var allNotifications = allTypes - .SelectMany(handlerType => handlerType.GetInterfaces() - .Where(i => i.IsGenericType && syncNotificationHandlerType == i.GetGenericTypeDefinition()) - .Select(i => new { HandlerType = handlerType, Interface = i })) - .ToList(); - - foreach (var notification in allNotifications) - { - services.AddScoped(notification.Interface, notification.HandlerType); - } - } + public static void RegisterNotification(IServiceCollection services, List allTypes, + Type syncNotificationHandlerType) + { + var allNotifications = allTypes + .SelectMany(handlerType => handlerType.GetInterfaces() + .Where(i => i.IsGenericType && syncNotificationHandlerType == i.GetGenericTypeDefinition()) + .Select(i => new { HandlerType = handlerType, Interface = i })) + .ToList(); + + foreach (var notification in allNotifications) + { + var serviceType = notification.Interface; + var implementationType = notification.HandlerType; + + if (serviceType.ContainsGenericParameters) + { + serviceType = serviceType.IsGenericTypeDefinition + ? serviceType + : serviceType.GetGenericTypeDefinition(); + implementationType = implementationType.IsGenericTypeDefinition + ? implementationType + : implementationType.GetGenericTypeDefinition(); + } + + services.AddScoped(serviceType, implementationType); + } + } private static bool IsAwaitable(Type type) { @@ -200,4 +213,4 @@ private static bool IsAwaitable(Type type) return false; } } -} \ No newline at end of file +} diff --git a/tests/DispatchR.IntegrationTest/NotificationTests.cs b/tests/DispatchR.IntegrationTest/NotificationTests.cs index 06dd78a..9d97844 100644 --- a/tests/DispatchR.IntegrationTest/NotificationTests.cs +++ b/tests/DispatchR.IntegrationTest/NotificationTests.cs @@ -103,4 +103,77 @@ public void RegisterNotification_SingleClassWithMultipleNotificationInterfaces_R Assert.Contains(handlers1, h => h is MultiNotificationHandler); Assert.Contains(handlers2, h => h is MultiNotificationHandler); } + + [Fact] + public async Task Publish_CallsOpenGenericAndSpecificHandlers_WhenBothAreRegistered() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddDispatchR(cfg => + { + cfg.Assemblies.Add(typeof(Fixture).Assembly); + cfg.RegisterPipelines = false; + cfg.RegisterNotifications = true; + }); + var serviceProvider = services.BuildServiceProvider(); + var mediator = serviceProvider.GetRequiredService(); + var executionStore = serviceProvider.GetRequiredService(); + + // Act + await mediator.Publish(new OpenGenericTargetNotification(Guid.NewGuid()), CancellationToken.None); + + // Assert + Assert.Equal(1, executionStore.Count($"generic:{nameof(OpenGenericTargetNotification)}")); + Assert.Equal(1, executionStore.Count($"specific:{nameof(OpenGenericTargetNotification)}")); + } + + [Fact] + public async Task PublishObject_CallsOpenGenericAndSpecificHandlers_WhenBothAreRegistered() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddDispatchR(cfg => + { + cfg.Assemblies.Add(typeof(Fixture).Assembly); + cfg.RegisterPipelines = false; + cfg.RegisterNotifications = true; + }); + var serviceProvider = services.BuildServiceProvider(); + var mediator = serviceProvider.GetRequiredService(); + var executionStore = serviceProvider.GetRequiredService(); + + // Act + object notificationObject = new OpenGenericTargetNotification(Guid.NewGuid()); + await mediator.Publish(notificationObject, CancellationToken.None); + + // Assert + Assert.Equal(1, executionStore.Count($"generic:{nameof(OpenGenericTargetNotification)}")); + Assert.Equal(1, executionStore.Count($"specific:{nameof(OpenGenericTargetNotification)}")); + } + + [Fact] + public async Task Publish_CallsOpenGenericHandler_WhenNoSpecificHandlerExists() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddDispatchR(cfg => + { + cfg.Assemblies.Add(typeof(Fixture).Assembly); + cfg.RegisterPipelines = false; + cfg.RegisterNotifications = true; + }); + var serviceProvider = services.BuildServiceProvider(); + var mediator = serviceProvider.GetRequiredService(); + var executionStore = serviceProvider.GetRequiredService(); + + // Act + await mediator.Publish(new OpenGenericOnlyNotification(Guid.NewGuid()), CancellationToken.None); + + // Assert + Assert.Equal(1, executionStore.Count($"generic:{nameof(OpenGenericOnlyNotification)}")); + Assert.Equal(0, executionStore.Count($"specific:{nameof(OpenGenericOnlyNotification)}")); + } } diff --git a/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericNotificationExecutionStore.cs b/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericNotificationExecutionStore.cs new file mode 100644 index 0000000..29ac22f --- /dev/null +++ b/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericNotificationExecutionStore.cs @@ -0,0 +1,18 @@ +using System.Collections.Concurrent; + +namespace DispatchR.TestCommon.Fixtures.Notification; + +public sealed class OpenGenericNotificationExecutionStore +{ + private readonly ConcurrentDictionary _counters = new(); + + public void Increment(string key) + { + _counters.AddOrUpdate(key, 1, (_, current) => current + 1); + } + + public int Count(string key) + { + return _counters.TryGetValue(key, out var count) ? count : 0; + } +} diff --git a/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericNotificationHandler.cs b/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericNotificationHandler.cs new file mode 100644 index 0000000..404fc38 --- /dev/null +++ b/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericNotificationHandler.cs @@ -0,0 +1,21 @@ +using DispatchR.Abstractions.Notification; + +namespace DispatchR.TestCommon.Fixtures.Notification; + +public sealed class OpenGenericNotificationHandler : INotificationHandler + where TNotification : INotification +{ + private static readonly OpenGenericNotificationExecutionStore FallbackStore = new(); + private readonly OpenGenericNotificationExecutionStore _store; + + public OpenGenericNotificationHandler(OpenGenericNotificationExecutionStore? store = null) + { + _store = store ?? FallbackStore; + } + + public ValueTask Handle(TNotification request, CancellationToken cancellationToken) + { + _store.Increment($"generic:{typeof(TNotification).Name}"); + return ValueTask.CompletedTask; + } +} diff --git a/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericOnlyNotification.cs b/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericOnlyNotification.cs new file mode 100644 index 0000000..5727b5d --- /dev/null +++ b/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericOnlyNotification.cs @@ -0,0 +1,5 @@ +using DispatchR.Abstractions.Notification; + +namespace DispatchR.TestCommon.Fixtures.Notification; + +public sealed record OpenGenericOnlyNotification(Guid Id) : INotification; diff --git a/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericTargetNotification.cs b/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericTargetNotification.cs new file mode 100644 index 0000000..54036a0 --- /dev/null +++ b/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericTargetNotification.cs @@ -0,0 +1,5 @@ +using DispatchR.Abstractions.Notification; + +namespace DispatchR.TestCommon.Fixtures.Notification; + +public sealed record OpenGenericTargetNotification(Guid Id) : INotification; diff --git a/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericTargetNotificationHandler.cs b/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericTargetNotificationHandler.cs new file mode 100644 index 0000000..fa3f003 --- /dev/null +++ b/tests/DispatchR.TestCommon/Fixtures/Notification/OpenGenericTargetNotificationHandler.cs @@ -0,0 +1,20 @@ +using DispatchR.Abstractions.Notification; + +namespace DispatchR.TestCommon.Fixtures.Notification; + +public sealed class OpenGenericTargetNotificationHandler : INotificationHandler +{ + private static readonly OpenGenericNotificationExecutionStore FallbackStore = new(); + private readonly OpenGenericNotificationExecutionStore _store; + + public OpenGenericTargetNotificationHandler(OpenGenericNotificationExecutionStore? store = null) + { + _store = store ?? FallbackStore; + } + + public ValueTask Handle(OpenGenericTargetNotification request, CancellationToken cancellationToken) + { + _store.Increment($"specific:{nameof(OpenGenericTargetNotification)}"); + return ValueTask.CompletedTask; + } +} diff --git a/tests/DispatchR.UnitTest/AddDispatchRConfigurationTests.cs b/tests/DispatchR.UnitTest/AddDispatchRConfigurationTests.cs index cd87485..6ca79a6 100644 --- a/tests/DispatchR.UnitTest/AddDispatchRConfigurationTests.cs +++ b/tests/DispatchR.UnitTest/AddDispatchRConfigurationTests.cs @@ -1,3 +1,4 @@ +using DispatchR.Abstractions.Notification; using DispatchR.Abstractions.Stream; using DispatchR.Exceptions; using DispatchR.Extensions; @@ -237,4 +238,30 @@ p.IsKeyedService is false && Assert.Equal(3, countOfAllSimpleHandlers); } -} \ No newline at end of file + + [Fact] + public void AddDispatchR_RegisterNotifications_IncludesOpenGenericNotificationHandler() + { + // Arrange + var services = new ServiceCollection(); + + // Act + services.AddDispatchR(cfg => + { + cfg.Assemblies.Add(typeof(Fixture).Assembly); + cfg.RegisterPipelines = false; + cfg.RegisterNotifications = true; + }); + + // Assert + var openGenericHandler = services.SingleOrDefault(p => + p.IsKeyedService is false && + p.ServiceType.IsGenericTypeDefinition && + p.ServiceType == typeof(INotificationHandler<>) && + p.ImplementationType is not null && + p.ImplementationType.IsGenericTypeDefinition && + p.ImplementationType == typeof(OpenGenericNotificationHandler<>)); + + Assert.NotNull(openGenericHandler); + } +}