diff --git a/readme.md b/readme.md index 35dcdbf..55be9eb 100644 --- a/readme.md +++ b/readme.md @@ -174,6 +174,63 @@ Note you can also register the same service using multiple keys, as shown in the > [!IMPORTANT] > Keyed services are a feature of version 8.0+ of Microsoft.Extensions.DependencyInjection +### Decorating Services + +After services are registered with `AddServices`, you can wrap existing registrations +with a decorator using the `Decorate()` extension method. +The decorated service must already be registered, and the source generator replaces +matching registrations in-place while preserving each registration's lifetime. + +The decorator type must implement `TDecorated` and provide a constructor that accepts +the decorated service as one of its parameters (additional dependencies are resolved +from the container as usual). Annotating the decorator with `[Service]` is optional +(when present, lifetime compatibility with the decorated service is validated at compile time): + +```csharp +public interface INotificationService +{ + void Send(string message); +} + +[Service(ServiceLifetime.Scoped)] +public class EmailNotificationService : INotificationService +{ + public void Send(string message) => Console.WriteLine($"[Email] {message}"); +} + +[Service(ServiceLifetime.Scoped)] +public class LoggingNotificationService(INotificationService inner) : INotificationService +{ + public void Send(string message) + { + Console.WriteLine("Sending notification..."); + inner.Send(message); + } +} + +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddServices(); +builder.Services.Decorate(); +``` + +When resolving `INotificationService`, the container returns a `LoggingNotificationService` +that wraps the original `EmailNotificationService`. `Func` and `Lazy` registrations +for the same service type also resolve the decorated instance. + +If multiple registrations exist for the same service type, all of them are decorated. +For keyed services, pass the key to decorate a specific registration: + +```csharp +builder.Services.AddServices(); +builder.Services.Decorate("email"); +``` + +The generator validates decorations at compile-time: the decorator must have a constructor +that accepts the decorated service type (plus any additional dependencies). If the decorator +is annotated with `[Service]`, its lifetime is also validated for compatibility with the +decorated registration(s). + ## How It Works In all cases, the generated code that implements the registration looks like the following: diff --git a/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj b/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj index b623a81..b3595bf 100644 --- a/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj +++ b/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj @@ -12,9 +12,9 @@ - - - + + + diff --git a/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs new file mode 100644 index 0000000..4a10477 --- /dev/null +++ b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs @@ -0,0 +1,244 @@ +using System.Collections.Immutable; +using System.IO; +using System.Threading.Tasks; +using Devlooped.Extensions.DependencyInjection; +using Microsoft.CodeAnalysis.CSharp.Testing; +using Microsoft.CodeAnalysis.Testing; +using Xunit; +using Xunit.Abstractions; +using Verifier = Microsoft.CodeAnalysis.CSharp.Testing.CSharpAnalyzerVerifier; + +namespace Tests.CodeAnalysis; + +public class DecorateGeneratorTests(ITestOutputHelper Output) +{ + [Fact] + public async Task ErrorIfDecoratorLifetimeIsIncompatible() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + + [Service(ServiceLifetime.Scoped)] + public class Foo : IFoo { } + + [Service(ServiceLifetime.Singleton)] + public class FooDecorator(IFoo inner) : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + {|#0:services.Decorate()|}; + } + } + """); + + test.ExpectedDiagnostics.Add( + Verifier.Diagnostic(IncrementalGenerator.DecoratorLifetimeIncompatible) + .WithLocation(0) + .WithArguments("FooDecorator", "Singleton", "IFoo", "Scoped")); + + await test.RunAsync(); + } + + [Fact] + public async Task ErrorIfDecoratorConstructorDoesNotAcceptDecoratedService() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + + [Service] + public class Foo : IFoo { } + + [Service] + public class FooDecorator : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + {|#0:services.Decorate()|}; + } + } + """); + + test.ExpectedDiagnostics.Add( + Verifier.Diagnostic(IncrementalGenerator.DecoratorConstructorMissing) + .WithLocation(0) + .WithArguments("FooDecorator", "IFoo")); + + await test.RunAsync(); + } + + [Fact] + public async Task NoErrorIfDecoratorConstructorHasOtherDependencies() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + public interface IOtherDependency { } + + [Service] + public class Foo : IFoo { } + + [Service] + public class OtherDependency : IOtherDependency { } + + [Service] + public class FooDecorator(IFoo inner, IOtherDependency other) : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + services.Decorate(); + } + } + """); + + await test.RunAsync(); + } + + [Fact] + public async Task NoErrorIfDecoratorHasNoServiceAttribute() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + + [Service] + public class Foo : IFoo { } + + // Decorator intentionally has NO [Service] attribute + public class FooDecorator(IFoo inner) : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + services.Decorate(); + } + } + """); + + // No diagnostics expected — decorator no longer requires [Service] + await test.RunAsync(); + } + + [Fact] + public async Task NoErrorIfKeyedDecoratorLifetimeMatchesSelectedKey() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + + [Service("foo", ServiceLifetime.Scoped)] + public class Foo : IFoo { } + + [Service("bar", ServiceLifetime.Singleton)] + public class OtherFoo : IFoo { } + + [Service("foo", ServiceLifetime.Scoped)] + public class FooDecorator(IFoo inner) : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + services.Decorate("foo"); + } + } + """); + + await test.RunAsync(); + } + + [Fact] + public async Task ErrorIfKeyedDecoratorLifetimeIsIncompatible() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + + [Service("foo", ServiceLifetime.Scoped)] + public class Foo : IFoo { } + + [Service("foo", ServiceLifetime.Singleton)] + public class FooDecorator(IFoo inner) : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + {|#0:services.Decorate("foo")|}; + } + } + """); + + test.ExpectedDiagnostics.Add( + Verifier.Diagnostic(IncrementalGenerator.DecoratorLifetimeIncompatible) + .WithLocation(0) + .WithArguments("FooDecorator", "Singleton", "IFoo", "Scoped")); + + await test.RunAsync(); + } + + static CSharpSourceGeneratorTest CreateTest(string source) + { + return new CSharpSourceGeneratorTest + { + TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck, + TestCode = source, + TestState = + { + AnalyzerConfigFiles = + { + ("/.editorconfig", + """ + is_global = true + build_property.AddServicesExtension = true + """) + }, + Sources = + { + ThisAssembly.Resources.AddServicesNoReflectionExtension.Text, + ThisAssembly.Resources.ServiceAttribute.Text, + ThisAssembly.Resources.ServiceAttribute_1.Text + }, + ReferenceAssemblies = new ReferenceAssemblies( + "net8.0", + new PackageIdentity( + "Microsoft.NETCore.App.Ref", "8.0.0"), + Path.Combine("ref", "net8.0")) + .AddPackages(ImmutableArray.Create( + new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) + }, + }.WithPreprocessorSymbols(); + } +} diff --git a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs index 1a8acba..7a1dbe0 100644 --- a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs +++ b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs @@ -10,6 +10,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.Extensions.DependencyInjection; +using DecoratedService = (Microsoft.CodeAnalysis.INamedTypeSymbol TDecorated, Microsoft.CodeAnalysis.INamedTypeSymbol TDecorator, bool IsKeyed, bool HasKeyValue, object? KeyValue, Microsoft.CodeAnalysis.Location? Location); using KeyedService = (Microsoft.CodeAnalysis.INamedTypeSymbol TImplementation, Microsoft.CodeAnalysis.INamedTypeSymbol? TService, Microsoft.CodeAnalysis.TypedConstant? Key); namespace Devlooped.Extensions.DependencyInjection; @@ -30,6 +31,24 @@ public class IncrementalGenerator : IIncrementalGenerator DiagnosticSeverity.Warning, isEnabledByDefault: true); + public static DiagnosticDescriptor DecoratorLifetimeIncompatible { get; } = + new DiagnosticDescriptor( + "DDI007", + "Decorator lifetime is incompatible with decorated services.", + "Decorator type {0} has lifetime {1}, which is incompatible with decorated service {2} lifetimes {3}.", + "Build", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static DiagnosticDescriptor DecoratorConstructorMissing { get; } = + new DiagnosticDescriptor( + "DDI008", + "Decorator constructor must accept the decorated service.", + "Decorator type {0} must have an accessible constructor with exactly one parameter of type {1}; additional dependencies are allowed.", + "Build", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + class ServiceSymbol(INamedTypeSymbol implementation, int lifetime, TypedConstant? key, Location? location, INamedTypeSymbol? service) { public INamedTypeSymbol TImplementation => implementation; @@ -66,6 +85,8 @@ record ServiceRegistration(int Lifetime, TypeSyntax? AssignableTo, string? FullN public Regex Regex => (regex ??= FullNameExpression is not null ? new(FullNameExpression) : new(".*")); } + record ServiceAttributeInfo(int Lifetime, TypedConstant? Key, INamedTypeSymbol? ServiceType, Location? Location); + public void Initialize(IncrementalGeneratorInitializationContext context) { var types = context.CompilationProvider.Combine(context.AnalyzerConfigOptionsProvider).SelectMany((x, c) => @@ -92,28 +113,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return visitor.TypeSymbols.Where(t => !t.IsAbstract && t.TypeKind == TypeKind.Class); }); - bool IsService(AttributeData attr) => - (attr.AttributeClass?.Name == "ServiceAttribute" || attr.AttributeClass?.Name == "Service") && - attr.ConstructorArguments.Length == 1 && - attr.ConstructorArguments[0].Kind == TypedConstantKind.Enum && - attr.ConstructorArguments[0].Type?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::Microsoft.Extensions.DependencyInjection.ServiceLifetime"; - - bool IsKeyedService(AttributeData attr) => - (attr.AttributeClass?.Name == "ServiceAttribute" || attr.AttributeClass?.Name == "Service" || - attr.AttributeClass?.Name == "KeyedService" || attr.AttributeClass?.Name == "KeyedServiceAttribute") && - //attr.AttributeClass?.IsGenericType == true && - attr.ConstructorArguments.Length == 2 && - attr.ConstructorArguments[1].Kind == TypedConstantKind.Enum && - attr.ConstructorArguments[1].Type?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::Microsoft.Extensions.DependencyInjection.ServiceLifetime"; - - bool IsExport(AttributeData attr) - { - var attrName = attr.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - return attrName == "global::System.Composition.ExportAttribute" || - attrName == "global::System.ComponentModel.Composition.ExportAttribute"; - } - ; - // NOTE: we recognize the attribute by name, not precise type. This makes the generator // more flexible and avoids requiring any sort of run-time dependency. @@ -129,8 +128,8 @@ bool IsExport(AttributeData attr) foreach (var attr in attrs) { - var serviceAttr = IsService(attr) || IsKeyedService(attr) ? attr : null; - if (serviceAttr == null && !IsExport(attr)) + var serviceAttr = IsServiceAttribute(attr) || IsKeyedServiceAttribute(attr) ? attr : null; + if (serviceAttr == null && !IsExportAttribute(attr)) continue; TypedConstant? key = default; @@ -139,7 +138,7 @@ bool IsExport(AttributeData attr) var lifetime = serviceAttr != null ? 0 : 2; if (serviceAttr != null) { - if (IsKeyedService(serviceAttr)) + if (IsKeyedServiceAttribute(serviceAttr)) { key = serviceAttr.ConstructorArguments[0]; lifetime = (int)serviceAttr.ConstructorArguments[1].Value!; @@ -149,7 +148,7 @@ bool IsExport(AttributeData attr) lifetime = (int)serviceAttr.ConstructorArguments[0].Value!; } } - else if (IsExport(attr)) + else if (IsExportAttribute(attr)) { // In NuGet MEF, [Shared] makes exports singleton if (attrs.Any(a => a.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::System.Composition.SharedAttribute")) @@ -216,6 +215,20 @@ serviceAttrSymbol.Symbol is IMethodSymbol attrCtor && .Select((x, _) => x.Left) .Collect(); + var decorations = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => node is InvocationExpressionSyntax invocation && GetInvokedMethodName(invocation) == "Decorate", + transform: static (ctx, cancellation) => GetDecoration((InvocationExpressionSyntax)ctx.Node, ctx.SemanticModel, cancellation)) + .Combine(context.AnalyzerConfigOptionsProvider) + .Where(x => + { + (var decoration, var options) = x; + return options.GlobalOptions.TryGetValue("build_property.AddServicesExtension", out var value) && + bool.TryParse(value, out var addServices) && addServices && decoration is not null; + }) + .Select((x, _) => x.Left!.Value) + .Collect(); + // Project matching service types to register with the given lifetime. var conventionServices = types.Combine(methodInvocations.Combine(context.CompilationProvider)).SelectMany((pair, cancellationToken) => { @@ -245,38 +258,54 @@ serviceAttrSymbol.Symbol is IMethodSymbol attrCtor && .SelectMany((tuple, _) => ImmutableArray.CreateRange([tuple.Item1, tuple.Item2])) .SelectMany((items, _) => items.Distinct().ToImmutableArray()); - RegisterServicesOutput(context, finalServices, context.CompilationProvider); + RegisterServicesOutput(context, finalServices, decorations, context.CompilationProvider); + RegisterDecorateOutput(context, decorations, finalServices.Collect(), context.CompilationProvider); } - void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider services, IncrementalValueProvider compilation) + void RegisterServicesOutput( + IncrementalGeneratorInitializationContext context, + IncrementalValuesProvider services, + IncrementalValueProvider> decorations, + IncrementalValueProvider compilation) { context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 0 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddSingleton", ctx, data)); + services.Where(x => x!.Lifetime == 0 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddSingleton", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 1 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddScoped", ctx, data)); + services.Where(x => x!.Lifetime == 1 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddScoped", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 2 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddTransient", ctx, data)); + services.Where(x => x!.Lifetime == 2 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddTransient", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 0 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddKeyedSingleton", ctx, data)); + services.Where(x => x!.Lifetime == 0 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddKeyedSingleton", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 1 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddKeyedScoped", ctx, data)); + services.Where(x => x!.Lifetime == 1 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddKeyedScoped", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 2 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddKeyedTransient", ctx, data)); + services.Where(x => x!.Lifetime == 2 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddKeyedTransient", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput(services.Collect(), ReportInconsistencies); } + void RegisterDecorateOutput( + IncrementalGeneratorInitializationContext context, + IncrementalValueProvider> decorations, + IncrementalValueProvider> services, + IncrementalValueProvider compilation) + { + context.RegisterImplementationSourceOutput( + decorations.Combine(services).Combine(compilation), + (ctx, data) => AddDecoratePartial(ctx, data.Left.Left, data.Left.Right, data.Right)); + } + void ReportInconsistencies(SourceProductionContext context, ImmutableArray array) { var grouped = array.GroupBy(x => x.TImplementation, SymbolEqualityComparer.Default).Where(g => g.Count() > 1).ToImmutableArray(); @@ -304,6 +333,313 @@ void ReportInconsistencies(SourceProductionContext context, ImmutableArray decorations, + ImmutableArray services, + Compilation compilation) + { + if (decorations.IsEmpty) + return; + + var validDecorations = ImmutableArray.CreateBuilder<(DecoratedService Decoration, IMethodSymbol Constructor)>(); + + foreach (var decoration in decorations) + { + if (!ValidateDecoration(ctx, decoration, services, compilation, out var constructor)) + continue; + + validDecorations.Add((decoration, constructor!)); + } + + if (validDecorations.Count == 0) + return; + + var builder = new StringBuilder() + .AppendLine("// ") + .AppendLine("#nullable enable"); + + foreach (var alias in compilation.References.SelectMany(r => r.Properties.Aliases)) + { + builder.AppendLine($"extern alias {alias};"); + } + + builder.AppendLine( + """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + static partial class AddServicesNoReflectionExtension + { + static partial void DecorateServices(IServiceCollection services) + where TDecorated : class + where TDecorator : class, TDecorated + { + """); + + // Emit non-keyed dispatch + for (var i = 0; i < validDecorations.Count; i++) + { + var (decoration, _) = validDecorations[i]; + if (decoration.IsKeyed) + continue; + + var decorated = decoration.TDecorated.ToFullName(compilation); + var decorator = decoration.TDecorator.ToFullName(compilation); + + builder.AppendLine($" if (typeof(TDecorated) == typeof({decorated}) && typeof(TDecorator) == typeof({decorator}))"); + builder.AppendLine(" {"); + builder.AppendLine($" DecorateDescriptors<{decorated}, {decorator}>(services, CreateDecorator{i});"); + builder.AppendLine(" return;"); + builder.AppendLine(" }"); + } + + builder.AppendLine( + """ + } + + static partial void DecorateKeyedServices(IServiceCollection services, object? key) + where TDecorated : class + where TDecorator : class, TDecorated + { + """); + + // Emit keyed dispatch + for (var i = 0; i < validDecorations.Count; i++) + { + var (decoration, _) = validDecorations[i]; + if (!decoration.IsKeyed) + continue; + + var decorated = decoration.TDecorated.ToFullName(compilation); + var decorator = decoration.TDecorator.ToFullName(compilation); + + builder.AppendLine($" if (typeof(TDecorated) == typeof({decorated}) && typeof(TDecorator) == typeof({decorator}))"); + builder.AppendLine(" {"); + builder.AppendLine($" DecorateDescriptors<{decorated}, {decorator}>(services, key, CreateKeyedDecorator{i});"); + builder.AppendLine(" return;"); + builder.AppendLine(" }"); + } + + builder.AppendLine( + """ + } + """); + + // Helper to build constructor arguments for a decorator factory. + // Handles the special case for the decorated service parameter and [FromKeyedServices]. + string BuildDecoratorArgs((DecoratedService Decoration, IMethodSymbol Constructor) entry, Compilation compilation, bool isKeyed) + { + var (decoration, ctor) = entry; + var decorated = decoration.TDecorated.ToFullName(compilation); + bool usedDecorated = false; + + return string.Join(", ", ctor.Parameters.Select(p => + { + if (!usedDecorated && SymbolEqualityComparer.Default.Equals(p.Type, decoration.TDecorated)) + { + usedDecorated = true; + return isKeyed + ? $"GetKeyedDecorated<{decorated}>(s, key, descriptor)" + : $"GetDecorated<{decorated}>(s, descriptor)"; + } + + var fromKeyed = p.GetAttributes().FirstOrDefault(IsFromKeyed); + if (fromKeyed is not null) + return $"s.GetRequiredKeyedService<{p.Type.ToFullName(compilation)}>({fromKeyed.ConstructorArguments[0].ToCSharpString()})"; + + return $"s.GetRequiredService<{p.Type.ToFullName(compilation)}>()"; + })); + } + + // Emit non-keyed factory methods + for (var i = 0; i < validDecorations.Count; i++) + { + var (decoration, ctor) = validDecorations[i]; + if (decoration.IsKeyed) + continue; + + var decorated = decoration.TDecorated.ToFullName(compilation); + var decorator = decoration.TDecorator.ToFullName(compilation); + var args = BuildDecoratorArgs((decoration, ctor), compilation, isKeyed: false); + + builder.AppendLine(); + builder.AppendLine($" static {decorated} CreateDecorator{i}(IServiceProvider s, ServiceDescriptor descriptor)"); + builder.AppendLine($" => new {decorator}({args});"); + } + + // Emit keyed factory methods + for (var i = 0; i < validDecorations.Count; i++) + { + var (decoration, ctor) = validDecorations[i]; + if (!decoration.IsKeyed) + continue; + + var decorated = decoration.TDecorated.ToFullName(compilation); + var decorator = decoration.TDecorator.ToFullName(compilation); + var args = BuildDecoratorArgs((decoration, ctor), compilation, isKeyed: true); + + builder.AppendLine(); + builder.AppendLine($" static {decorated} CreateKeyedDecorator{i}(IServiceProvider s, object? key, ServiceDescriptor descriptor)"); + builder.AppendLine($" => new {decorator}({args});"); + } + + builder.AppendLine( + """ + } + } + """); + + ctx.AddSource("Decorate.g", builder.ToString().Replace("\r\n", "\n").Replace("\n", Environment.NewLine)); + } + + bool ValidateDecoration( + SourceProductionContext ctx, + DecoratedService decoration, + ImmutableArray services, + Compilation compilation, + out IMethodSymbol? constructor) + { + constructor = GetDecoratorConstructor(decoration, compilation); + var isValid = true; + var decoratorLifetimes = GetDecoratorLifetimes(decoration, compilation); + + var decoratedLifetimes = GetDecoratedLifetimes(decoration, services, compilation); + if (!decoratorLifetimes.IsEmpty && !decoratedLifetimes.IsEmpty && + (decoratorLifetimes.Length != 1 || decoratedLifetimes.Any(x => x != decoratorLifetimes[0]))) + { + ctx.ReportDiagnostic(Diagnostic.Create( + DecoratorLifetimeIncompatible, + decoration.Location, + decoration.TDecorator.ToDisplayString(), + string.Join(", ", decoratorLifetimes.Select(LifetimeName)), + decoration.TDecorated.ToDisplayString(), + string.Join(", ", decoratedLifetimes.Select(LifetimeName)))); + isValid = false; + } + + if (constructor is null) + { + ctx.ReportDiagnostic(Diagnostic.Create( + DecoratorConstructorMissing, + decoration.Location, + decoration.TDecorator.ToDisplayString(), + decoration.TDecorated.ToDisplayString())); + isValid = false; + } + + return isValid; + } + + static ImmutableArray GetDecoratorLifetimes(DecoratedService decoration, Compilation compilation) + { + return GetServiceAttributes(decoration.TDecorator) + .Where(x => IsMatchingDecoratorAttribute(decoration, x)) + .Where(x => + x.ServiceType is null ? + compilation.HasImplicitConversion(decoration.TDecorator, decoration.TDecorated) : + SymbolEqualityComparer.Default.Equals(x.ServiceType, decoration.TDecorated)) + .Select(x => x.Lifetime) + .Distinct() + .ToImmutableArray(); + } + + static ImmutableArray GetDecoratedLifetimes(DecoratedService decoration, ImmutableArray services, Compilation compilation) + { + return services + .Where(x => decoration.IsKeyed ? x.Key is not null : x.Key is null) + .Where(x => !decoration.IsKeyed || !decoration.HasKeyValue || Equals(x.Key?.Value, decoration.KeyValue)) + .Where(x => !SymbolEqualityComparer.Default.Equals(x.TImplementation, decoration.TDecorator)) + .Where(x => + x.TService is null ? + compilation.HasImplicitConversion(x.TImplementation, decoration.TDecorated) : + SymbolEqualityComparer.Default.Equals(x.TService, decoration.TDecorated)) + .Select(x => x.Lifetime) + .Distinct() + .ToImmutableArray(); + } + + static bool IsMatchingDecoratorAttribute(DecoratedService decoration, ServiceAttributeInfo attribute) => + decoration.IsKeyed || attribute.Key is null; + + IMethodSymbol? GetDecoratorConstructor(DecoratedService decoration, Compilation compilation) + { + var candidates = decoration.TDecorator.InstanceConstructors + .Where(x => compilation.IsSymbolAccessible(x)) + .Where(x => x.Parameters.Count(p => SymbolEqualityComparer.Default.Equals(p.Type, decoration.TDecorated)) == 1) + .ToImmutableArray(); + + if (candidates.IsDefaultOrEmpty) + return null; + + return candidates.FirstOrDefault(HasImportingConstructor) ?? + candidates.OrderByDescending(m => m.Parameters.Length).FirstOrDefault(); + } + + static bool HasImportingConstructor(IMethodSymbol method) => + method.GetAttributes().Any(a => + a.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::System.Composition.ImportingConstructorAttribute" || + a.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::System.ComponentModel.Composition.ImportingConstructorAttribute"); + + static bool IsDecoratorServiceAlias(INamedTypeSymbol implementation, INamedTypeSymbol service, bool keyed, ImmutableArray decorations) => + decorations.Any(x => + x.IsKeyed == keyed && + SymbolEqualityComparer.Default.Equals(x.TDecorator, implementation) && + SymbolEqualityComparer.Default.Equals(x.TDecorated, service)); + + static ImmutableArray GetServiceAttributes(INamedTypeSymbol type) + { + var builder = ImmutableArray.CreateBuilder(); + + foreach (var attr in type.GetAttributes()) + { + if (!IsServiceAttribute(attr) && !IsKeyedServiceAttribute(attr)) + continue; + + var lifetime = IsKeyedServiceAttribute(attr) ? + (int)attr.ConstructorArguments[1].Value! : + (int)attr.ConstructorArguments[0].Value!; + var key = IsKeyedServiceAttribute(attr) ? attr.ConstructorArguments[0] : (TypedConstant?)null; + var serviceType = attr.AttributeClass?.IsGenericType == true && + attr.AttributeClass.TypeArguments.Length == 1 && + attr.AttributeClass.TypeArguments[0] is INamedTypeSymbol namedService ? + namedService : + null; + + builder.Add(new ServiceAttributeInfo( + lifetime, + key, + serviceType, + attr.ApplicationSyntaxReference?.GetSyntax().GetLocation())); + } + + return builder.ToImmutable(); + } + + static bool IsServiceAttribute(AttributeData attr) => + (attr.AttributeClass?.Name == "ServiceAttribute" || attr.AttributeClass?.Name == "Service") && + attr.ConstructorArguments.Length == 1 && + attr.ConstructorArguments[0].Kind == TypedConstantKind.Enum && + attr.ConstructorArguments[0].Type?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::Microsoft.Extensions.DependencyInjection.ServiceLifetime"; + + static bool IsKeyedServiceAttribute(AttributeData attr) => + (attr.AttributeClass?.Name == "ServiceAttribute" || attr.AttributeClass?.Name == "Service" || + attr.AttributeClass?.Name == "KeyedService" || attr.AttributeClass?.Name == "KeyedServiceAttribute") && + attr.ConstructorArguments.Length == 2 && + attr.ConstructorArguments[1].Kind == TypedConstantKind.Enum && + attr.ConstructorArguments[1].Type?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::Microsoft.Extensions.DependencyInjection.ServiceLifetime"; + + static bool IsExportAttribute(AttributeData attr) + { + var attrName = attr.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + return attrName == "global::System.Composition.ExportAttribute" || + attrName == "global::System.ComponentModel.Composition.ExportAttribute"; + } + + static string LifetimeName(int lifetime) => + lifetime switch { 0 => "Singleton", 1 => "Scoped", 2 => "Transient", _ => "Unknown" }; + static string? GetInvokedMethodName(InvocationExpressionSyntax invocation) => invocation.Expression switch { MemberAccessExpressionSyntax memberAccess => memberAccess.Name.Identifier.Text, @@ -362,7 +698,40 @@ void ReportInconsistencies(SourceProductionContext context, ImmutableArray Types, Compilation Compilation) data) + static DecoratedService? GetDecoration(InvocationExpressionSyntax invocation, SemanticModel semanticModel, CancellationToken cancellation) + { + var symbolInfo = semanticModel.GetSymbolInfo(invocation, cancellation); + if (symbolInfo.Symbol is not IMethodSymbol methodSymbol) + return null; + + if (!HasMethodAttribute(methodSymbol, "DDIDecorateAttribute")) + return null; + + if (methodSymbol.TypeArguments.Length != 2 || + methodSymbol.TypeArguments[0] is not INamedTypeSymbol decorated || + methodSymbol.TypeArguments[1] is not INamedTypeSymbol decorator) + return null; + + var isKeyed = methodSymbol.Parameters.Any(p => p.Name == "key") || + methodSymbol.ReducedFrom?.Parameters.Any(p => p.Name == "key") == true; + var hasKeyValue = false; + object? keyValue = null; + + if (isKeyed && invocation.ArgumentList.Arguments.Count > 0) + { + var key = semanticModel.GetConstantValue(invocation.ArgumentList.Arguments[0].Expression, cancellation); + hasKeyValue = key.HasValue; + keyValue = key.Value; + } + + return new DecoratedService(decorated, decorator, isKeyed, hasKeyValue, keyValue, invocation.GetLocation()); + } + + static bool HasMethodAttribute(IMethodSymbol method, string attributeName) => + method.GetAttributes().Any(attr => attr.AttributeClass?.Name == attributeName) || + method.ReducedFrom?.GetAttributes().Any(attr => attr.AttributeClass?.Name == attributeName) == true; + + void AddPartial(string methodName, SourceProductionContext ctx, (ImmutableArray Types, ImmutableArray Decorations, Compilation Compilation) data) { if (data.Types.IsEmpty) return; @@ -388,8 +757,8 @@ static partial class AddServicesNoReflectionExtension { """); - AddServices(data.Types.Where(x => x.Key is null), data.Compilation, methodName, builder); - AddKeyedServices(data.Types.Where(x => x.Key is not null), data.Compilation, methodName, builder); + AddServices(data.Types.Where(x => x.Key is null), data.Decorations, data.Compilation, methodName, builder); + AddKeyedServices(data.Types.Where(x => x.Key is not null), data.Decorations, data.Compilation, methodName, builder); builder.AppendLine( """ @@ -401,7 +770,7 @@ static partial class AddServicesNoReflectionExtension ctx.AddSource(methodName + ".g", builder.ToString().Replace("\r\n", "\n").Replace("\n", Environment.NewLine)); } - void AddServices(IEnumerable services, Compilation compilation, string methodName, StringBuilder output) + void AddServices(IEnumerable services, ImmutableArray decorations, Compilation compilation, string methodName, StringBuilder output) { bool isAccessible(ISymbol s) => compilation.IsSymbolAccessible(s); @@ -445,6 +814,9 @@ void AddServices(IEnumerable services, Compilation compilation, st foreach (var iface in serviceTypes) { + if (IsDecoratorServiceAlias(type, iface, keyed: false, decorations)) + continue; + if (!compilation.HasImplicitConversion(type, iface)) continue; @@ -492,7 +864,7 @@ void AddServices(IEnumerable services, Compilation compilation, st } } - void AddKeyedServices(IEnumerable services, Compilation compilation, string methodName, StringBuilder output) + void AddKeyedServices(IEnumerable services, ImmutableArray decorations, Compilation compilation, string methodName, StringBuilder output) { bool isAccessible(ISymbol s) => compilation.IsSymbolAccessible(s); @@ -537,6 +909,9 @@ void AddKeyedServices(IEnumerable services, Compilation compilatio foreach (var iface in serviceTypes) { + if (IsDecoratorServiceAlias(type, iface, keyed: true, decorations)) + continue; + var ifaceName = iface.ToFullName(compilation); if (!registered.Contains(ifaceName)) { diff --git a/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs b/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs index d3141cb..0cf32d3 100644 --- a/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs +++ b/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs @@ -1,6 +1,8 @@ // +#nullable enable using System; using System.ComponentModel; +using System.Linq; namespace Microsoft.Extensions.DependencyInjection { @@ -89,6 +91,47 @@ public static IServiceCollection AddServices(this IServiceCollection services) return services; } + /// + /// Decorates registrations for with . + /// + /// + /// The decorated service must already be registered before this method is invoked. The generated + /// implementation replaces matching registrations in-place and preserves each registration lifetime. + /// + /// The service type to decorate. + /// The decorator type. It must implement . + /// The to update. + /// The so that additional calls can be chained. + [DDIDecorate] + public static IServiceCollection Decorate(this IServiceCollection services) + where TDecorated : class + where TDecorator : class, TDecorated + { + DecorateServices(services); + return services; + } + + /// + /// Decorates keyed registrations for with . + /// + /// + /// The decorated service must already be registered with before this method is invoked. + /// The generated implementation replaces matching registrations in-place and preserves each registration lifetime. + /// + /// The service type to decorate. + /// The decorator type. It must implement . + /// The to update. + /// The key for the service registration to decorate. + /// The so that additional calls can be chained. + [DDIDecorate] + public static IServiceCollection Decorate(this IServiceCollection services, object? key) + where TDecorated : class + where TDecorator : class, TDecorated + { + DecorateKeyedServices(services, key); + return services; + } + /// /// Adds discovered scoped services to the collection. /// @@ -119,7 +162,157 @@ public static IServiceCollection AddServices(this IServiceCollection services) /// static partial void AddKeyedTransientServices(IServiceCollection services); + /// + /// Applies source-generated service decorations. + /// + static partial void DecorateServices(IServiceCollection services) + where TDecorated : class + where TDecorator : class, TDecorated; + + /// + /// Applies source-generated keyed service decorations. + /// + static partial void DecorateKeyedServices(IServiceCollection services, object? key) + where TDecorated : class + where TDecorator : class, TDecorated; + + static void DecorateDescriptors( + IServiceCollection services, + Func factory) + where TDecorated : class + where TDecorator : class, TDecorated + { + DecorateDescriptors(services, null, (sp, _, sd) => factory(sp, sd)); + } + + static void DecorateDescriptors( + IServiceCollection services, + object? key, + Func factory) + where TDecorated : class + where TDecorator : class, TDecorated + { + bool isKeyed = key is not null; + + var descriptors = services + .Select((descriptor, index) => new { descriptor, index }) + .Where(x => + x.descriptor.IsKeyedService == isKeyed && + x.descriptor.ServiceType == typeof(TDecorated) && + (!isKeyed || Equals(x.descriptor.ServiceKey, key)) && + GetImplementationType(x.descriptor) != typeof(TDecorator)) + .ToArray(); + + if (descriptors.Length == 0) + { + throw new InvalidOperationException(isKeyed + ? $"No keyed service registration for {typeof(TDecorated)} with key '{key}' was found. Call AddServices before Decorate, or register the decorated service before decorating it." + : $"No service registration for {typeof(TDecorated)} was found. Call AddServices before Decorate, or register the decorated service before decorating it."); + } + + foreach (var item in descriptors) + { + if (!isKeyed) + { + services[item.index] = ServiceDescriptor.Describe( + typeof(TDecorated), + provider => factory(provider, null, item.descriptor), + item.descriptor.Lifetime); + } + else + { + services[item.index] = ServiceDescriptor.DescribeKeyed( + typeof(TDecorated), + key, + (provider, serviceKey) => factory(provider, serviceKey, item.descriptor), + item.descriptor.Lifetime); + } + } + } + + + + static TDecorated GetDecorated(IServiceProvider provider, ServiceDescriptor descriptor) + where TDecorated : class + => GetDecoratedCore(provider, null, descriptor); + + static TDecorated GetKeyedDecorated(IServiceProvider provider, object? key, ServiceDescriptor descriptor) + where TDecorated : class + => GetDecoratedCore(provider, key, descriptor); + + static TDecorated GetDecoratedCore(IServiceProvider provider, object? key, ServiceDescriptor descriptor) + where TDecorated : class + { + bool isKeyed = key is not null; + + if (isKeyed) + { + if (!descriptor.IsKeyedService) + throw new InvalidOperationException($"Non-keyed service registrations for {typeof(TDecorated)} cannot be decorated by this overload."); + } + else + { + if (descriptor.IsKeyedService) + throw new InvalidOperationException($"Keyed service registrations for {typeof(TDecorated)} cannot be decorated by this overload."); + } + + object? service; + + if (descriptor.IsKeyedService) + { + if (descriptor.KeyedImplementationInstance != null) + { + service = descriptor.KeyedImplementationInstance; + } + else if (descriptor.KeyedImplementationFactory != null) + { + service = descriptor.KeyedImplementationFactory(provider, key); + } + else if (descriptor.KeyedImplementationType != null) + { + service = ActivatorUtilities.GetServiceOrCreateInstance(provider, descriptor.KeyedImplementationType); + } + else + { + throw new InvalidOperationException($"Unsupported keyed service registration for {typeof(TDecorated)}."); + } + } + else + { + if (descriptor.ImplementationInstance != null) + { + service = descriptor.ImplementationInstance; + } + else if (descriptor.ImplementationFactory != null) + { + service = descriptor.ImplementationFactory(provider); + } + else if (descriptor.ImplementationType != null) + { + service = ActivatorUtilities.GetServiceOrCreateInstance(provider, descriptor.ImplementationType); + } + else + { + throw new InvalidOperationException($"Unsupported service registration for {typeof(TDecorated)}."); + } + } + + return service as TDecorated ?? + throw new InvalidOperationException($"The decorated registration did not produce an instance of {typeof(TDecorated)}."); + } + + static Type? GetImplementationType(ServiceDescriptor descriptor) + { + if (descriptor.IsKeyedService) + return descriptor.KeyedImplementationType; + + return descriptor.ImplementationType ?? descriptor.ImplementationInstance?.GetType(); + } + [AttributeUsage(AttributeTargets.Method)] class DDIAddServicesAttribute : Attribute { } + + [AttributeUsage(AttributeTargets.Method)] + class DDIDecorateAttribute : Attribute { } } } \ No newline at end of file diff --git a/src/DependencyInjection.Tests/GenerationTests.cs b/src/DependencyInjection.Tests/GenerationTests.cs index 9cb9379..5307a41 100644 --- a/src/DependencyInjection.Tests/GenerationTests.cs +++ b/src/DependencyInjection.Tests/GenerationTests.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.ComponentModel.Composition; using System.Linq; using System.Threading; @@ -283,6 +284,135 @@ public void RegisterWithSpecificServiceType() Assert.Null(services.GetService()); } + [Fact] + public void DecorateService() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate(); + var services = collection.BuildServiceProvider(); + + using var scope = services.CreateScope(); + + var instance = Assert.IsType(scope.ServiceProvider.GetRequiredService()); + + Assert.IsType(instance.Inner); + Assert.Same(instance, scope.ServiceProvider.GetRequiredService()); + Assert.Same(instance, scope.ServiceProvider.GetRequiredService>().Invoke()); + Assert.Same(instance, scope.ServiceProvider.GetRequiredService>().Value); + Assert.Same(services.GetRequiredService(), instance.Singleton); + } + + [Fact] + public void DecorateMultipleRegistrations() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate(); + var services = collection.BuildServiceProvider(); + + var instances = services.GetServices() + .Cast() + .ToList(); + + Assert.Equal(2, instances.Count); + Assert.Contains(instances, x => x.Inner is FirstMultipleDecoratedService); + Assert.Contains(instances, x => x.Inner is SecondMultipleDecoratedService); + } + + [Fact] + public void DecorateMultipleRegistrationsAsIEnumerableDependency() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate(); + var services = collection.BuildServiceProvider(); + + var consumer = services.GetRequiredService(); + + var decorated = consumer.Services.Cast().ToList(); + Assert.Equal(2, decorated.Count); + Assert.Contains(decorated, x => x.Inner is FirstMultipleDecoratedService); + Assert.Contains(decorated, x => x.Inner is SecondMultipleDecoratedService); + } + + [Fact] + public void DecorateKeyedService() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate("decorated"); + var services = collection.BuildServiceProvider(); + + var instance = Assert.IsType( + services.GetRequiredKeyedService("decorated")); + + Assert.IsType(instance.Inner); + Assert.Same(services.GetRequiredService(), instance.Singleton); + + var factory = services.GetRequiredKeyedService>("decorated"); + Assert.IsType(factory()); + Assert.IsType(services.GetRequiredKeyedService("other")); + } + + [Fact] + public void DecorateThrowsIfDecoratedServiceIsNotRegistered() + { + var collection = new ServiceCollection(); + + var ex = Assert.Throws(() => + collection.Decorate()); + + Assert.Contains(nameof(IDecoratedService), ex.Message); + } + + [Fact] + public void DecorateKeyedThrowsIfDecoratedServiceIsNotRegistered() + { + var collection = new ServiceCollection(); + collection.AddServices(); + + var ex = Assert.Throws(() => + collection.Decorate("missing")); + + Assert.Contains(nameof(IKeyedDecoratedService), ex.Message); + Assert.Contains("missing", ex.Message); + } + + [Fact] + public void DecorateServiceReadmeExample() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate(); + var services = collection.BuildServiceProvider(); + + using var scope = services.CreateScope(); + var provider = scope.ServiceProvider; + + var instance = Assert.IsType(provider.GetRequiredService()); + Assert.IsType(instance.Inner); + + Assert.Same(instance, provider.GetRequiredService>().Invoke()); + Assert.Same(instance, provider.GetRequiredService>().Value); + } + + [Fact] + public void DecorateWorksWithDecoratorThatHasNoServiceAttribute() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate(); + var services = collection.BuildServiceProvider(); + + var instance = Assert.IsType(services.GetRequiredService()); + Assert.IsType(instance.Inner); + Assert.Same(services.GetRequiredService(), instance.Singleton); + + Assert.Same(instance, services.GetRequiredService>().Invoke()); + Assert.Same(instance, services.GetRequiredService>().Value); + } + [GenerationTests.Service(ServiceLifetime.Singleton)] public class MyAttributedService : IAsyncDisposable { @@ -406,4 +536,85 @@ public interface INonSpecificService; public class SpecificServiceType : ISpecificService, INonSpecificService { public void Dispose() => throw new NotImplementedException(); -} \ No newline at end of file +} + +public interface IDecoratedService { } + +[Service(ServiceLifetime.Scoped)] +public class DecoratedService : IDecoratedService { } + +[Service(ServiceLifetime.Scoped)] +public class DecoratedServiceDecorator(IDecoratedService inner, SingletonService singleton) : IDecoratedService +{ + public IDecoratedService Inner => inner; + public SingletonService Singleton => singleton; +} + +public interface IMultipleDecoratedService { } + +[Service(ServiceLifetime.Transient)] +public class FirstMultipleDecoratedService : IMultipleDecoratedService { } + +[Service(ServiceLifetime.Transient)] +public class SecondMultipleDecoratedService : IMultipleDecoratedService { } + +[Service(ServiceLifetime.Transient)] +public class MultipleDecoratedServiceDecorator(IMultipleDecoratedService inner) : IMultipleDecoratedService +{ + public IMultipleDecoratedService Inner => inner; +} + +public interface IKeyedDecoratedService { } + +[Service("decorated", ServiceLifetime.Singleton)] +public class KeyedDecoratedService : IKeyedDecoratedService { } + +[Service("other", ServiceLifetime.Singleton)] +public class OtherKeyedDecoratedService : IKeyedDecoratedService { } + +[Service("decorated", ServiceLifetime.Singleton)] +public class KeyedDecoratedServiceDecorator(IKeyedDecoratedService inner, SingletonService singleton) : IKeyedDecoratedService +{ + public IKeyedDecoratedService Inner => inner; + public SingletonService Singleton => singleton; +} + +public interface IReadmeNotificationService +{ + void Send(string message); +} + +[Service(ServiceLifetime.Scoped)] +public class ReadmeEmailNotificationService : IReadmeNotificationService +{ + public void Send(string message) => Console.WriteLine($"[Email] {message}"); +} + +[Service(ServiceLifetime.Scoped)] +public class ReadmeLoggingNotificationService(IReadmeNotificationService inner) : IReadmeNotificationService +{ + public IReadmeNotificationService Inner => inner; + public void Send(string message) + { + Console.WriteLine("Sending notification..."); + inner.Send(message); + } +} + +public interface IUnattributedDecorated { } + +[Service] +public class UnattributedDecoratedImpl : IUnattributedDecorated { } + +// Intentionally no [Service] attribute — decoration should still work. +public class UnattributedDecorator(IUnattributedDecorated inner, SingletonService singleton) : IUnattributedDecorated +{ + public IUnattributedDecorated Inner => inner; + public SingletonService Singleton => singleton; +} + +[Service] +public class IEnumerableDecoratedConsumer(IEnumerable services) +{ + public IEnumerable Services => services; +} diff --git a/src/Samples/ConsoleApp/Program.cs b/src/Samples/ConsoleApp/Program.cs index c78b9fa..7003e34 100644 --- a/src/Samples/ConsoleApp/Program.cs +++ b/src/Samples/ConsoleApp/Program.cs @@ -1,6 +1,7 @@ extern alias Library1; extern alias Library2; +using System.Diagnostics; using Merq; using Microsoft.Extensions.DependencyInjection; @@ -10,9 +11,27 @@ // Library1 contains [Service]-annotated classes, which will be automatically registered here. collection.AddServices(); +// Wrap the echo handler with a stopwatch timer that logs execution time. +collection.Decorate, EchoHandlerTimer>(); + var services = collection.BuildServiceProvider(); var handler = services.GetRequiredService>(); var message = handler.Execute(new Library1::Library.Echo("Hello")); Console.WriteLine(message); + +[Service] +class EchoHandlerTimer(ICommandHandler inner) : ICommandHandler +{ + public bool CanExecute(Library1::Library.Echo command) => inner.CanExecute(command); + + public string Execute(Library1::Library.Echo command) + { + var stopwatch = Stopwatch.StartNew(); + var result = inner.Execute(command); + stopwatch.Stop(); + Console.WriteLine($"Echo executed in {stopwatch.Elapsed.TotalMilliseconds:F3} ms"); + return result; + } +} \ No newline at end of file