From 6c82f08d31f6765fb53c6847d121c6683920bcd5 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Sun, 28 Jun 2026 20:49:50 +0000 Subject: [PATCH 1/9] Add source-generated Decorate marker Co-authored-by: Daniel Cazzulino --- .../DecorateGeneratorTests.cs | 148 +++++++ .../IncrementalGenerator.cs | 402 ++++++++++++++++-- .../AddServicesNoReflectionExtension.cs | 99 +++++ .../GenerationTests.cs | 73 ++++ 4 files changed, 677 insertions(+), 45 deletions(-) create mode 100644 src/CodeAnalysis.Tests/DecorateGeneratorTests.cs diff --git a/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs new file mode 100644 index 0000000..f1cca08 --- /dev/null +++ b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs @@ -0,0 +1,148 @@ +using System.Collections.Immutable; +using System.IO; +using System.Threading.Tasks; +using Devlooped.Extensions.DependencyInjection; +using Microsoft.CodeAnalysis.CSharp.Testing; +using Microsoft.CodeAnalysis.Testing; +using Xunit; +using Xunit.Abstractions; +using Verifier = Microsoft.CodeAnalysis.CSharp.Testing.CSharpAnalyzerVerifier; + +namespace Tests.CodeAnalysis; + +public class DecorateGeneratorTests(ITestOutputHelper Output) +{ + [Fact] + public async Task ErrorIfDecoratorIsNotService() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + + [Service] + public class Foo : IFoo { } + + public class FooDecorator(IFoo inner) : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + services.{|#0:Decorate()|}; + } + } + """); + + test.ExpectedDiagnostics.Add( + Verifier.Diagnostic(IncrementalGenerator.DecoratorMustBeService) + .WithLocation(0) + .WithArguments("FooDecorator")); + + await test.RunAsync(); + } + + [Fact] + public async Task ErrorIfDecoratorLifetimeIsIncompatible() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + + [Service(ServiceLifetime.Scoped)] + public class Foo : IFoo { } + + [Service(ServiceLifetime.Singleton)] + public class FooDecorator(IFoo inner) : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + services.{|#0:Decorate()|}; + } + } + """); + + test.ExpectedDiagnostics.Add( + Verifier.Diagnostic(IncrementalGenerator.DecoratorLifetimeIncompatible) + .WithLocation(0) + .WithArguments("FooDecorator", "Singleton", "IFoo", "Scoped")); + + await test.RunAsync(); + } + + [Fact] + public async Task ErrorIfDecoratorConstructorDoesNotAcceptDecoratedService() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + + [Service] + public class Foo : IFoo { } + + [Service] + public class FooDecorator : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + services.{|#0:Decorate()|}; + } + } + """); + + test.ExpectedDiagnostics.Add( + Verifier.Diagnostic(IncrementalGenerator.DecoratorConstructorMissing) + .WithLocation(0) + .WithArguments("FooDecorator", "IFoo")); + + await test.RunAsync(); + } + + static CSharpSourceGeneratorTest CreateTest(string source) + { + return new CSharpSourceGeneratorTest + { + TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck, + TestCode = source, + TestState = + { + AnalyzerConfigFiles = + { + ("/.editorconfig", + """ + is_global = true + build_property.AddServicesExtension = true + """) + }, + Sources = + { + ThisAssembly.Resources.AddServicesNoReflectionExtension.Text, + ThisAssembly.Resources.ServiceAttribute.Text, + ThisAssembly.Resources.ServiceAttribute_1.Text + }, + ReferenceAssemblies = new ReferenceAssemblies( + "net8.0", + new PackageIdentity( + "Microsoft.NETCore.App.Ref", "8.0.0"), + Path.Combine("ref", "net8.0")) + .AddPackages(ImmutableArray.Create( + new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) + }, + }.WithPreprocessorSymbols(); + } +} diff --git a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs index 1a8acba..24f2596 100644 --- a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs +++ b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs @@ -11,6 +11,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.Extensions.DependencyInjection; using KeyedService = (Microsoft.CodeAnalysis.INamedTypeSymbol TImplementation, Microsoft.CodeAnalysis.INamedTypeSymbol? TService, Microsoft.CodeAnalysis.TypedConstant? Key); +using DecoratedService = (Microsoft.CodeAnalysis.INamedTypeSymbol TDecorated, Microsoft.CodeAnalysis.INamedTypeSymbol TDecorator, Microsoft.CodeAnalysis.Location? Location); namespace Devlooped.Extensions.DependencyInjection; @@ -30,6 +31,33 @@ public class IncrementalGenerator : IIncrementalGenerator DiagnosticSeverity.Warning, isEnabledByDefault: true); + public static DiagnosticDescriptor DecoratorMustBeService { get; } = + new DiagnosticDescriptor( + "DDI006", + "Decorator must be annotated with ServiceAttribute.", + "Decorator type {0} must be annotated with [Service] so its registration is generated before decoration.", + "Build", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static DiagnosticDescriptor DecoratorLifetimeIncompatible { get; } = + new DiagnosticDescriptor( + "DDI007", + "Decorator lifetime is incompatible with decorated services.", + "Decorator type {0} has lifetime {1}, which is incompatible with decorated service {2} lifetimes {3}.", + "Build", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static DiagnosticDescriptor DecoratorConstructorMissing { get; } = + new DiagnosticDescriptor( + "DDI008", + "Decorator constructor must accept the decorated service.", + "Decorator type {0} must have an accessible constructor with exactly one parameter of type {1}.", + "Build", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + class ServiceSymbol(INamedTypeSymbol implementation, int lifetime, TypedConstant? key, Location? location, INamedTypeSymbol? service) { public INamedTypeSymbol TImplementation => implementation; @@ -66,6 +94,8 @@ record ServiceRegistration(int Lifetime, TypeSyntax? AssignableTo, string? FullN public Regex Regex => (regex ??= FullNameExpression is not null ? new(FullNameExpression) : new(".*")); } + record ServiceAttributeInfo(int Lifetime, TypedConstant? Key, INamedTypeSymbol? ServiceType, Location? Location); + public void Initialize(IncrementalGeneratorInitializationContext context) { var types = context.CompilationProvider.Combine(context.AnalyzerConfigOptionsProvider).SelectMany((x, c) => @@ -92,28 +122,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return visitor.TypeSymbols.Where(t => !t.IsAbstract && t.TypeKind == TypeKind.Class); }); - bool IsService(AttributeData attr) => - (attr.AttributeClass?.Name == "ServiceAttribute" || attr.AttributeClass?.Name == "Service") && - attr.ConstructorArguments.Length == 1 && - attr.ConstructorArguments[0].Kind == TypedConstantKind.Enum && - attr.ConstructorArguments[0].Type?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::Microsoft.Extensions.DependencyInjection.ServiceLifetime"; - - bool IsKeyedService(AttributeData attr) => - (attr.AttributeClass?.Name == "ServiceAttribute" || attr.AttributeClass?.Name == "Service" || - attr.AttributeClass?.Name == "KeyedService" || attr.AttributeClass?.Name == "KeyedServiceAttribute") && - //attr.AttributeClass?.IsGenericType == true && - attr.ConstructorArguments.Length == 2 && - attr.ConstructorArguments[1].Kind == TypedConstantKind.Enum && - attr.ConstructorArguments[1].Type?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::Microsoft.Extensions.DependencyInjection.ServiceLifetime"; - - bool IsExport(AttributeData attr) - { - var attrName = attr.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - return attrName == "global::System.Composition.ExportAttribute" || - attrName == "global::System.ComponentModel.Composition.ExportAttribute"; - } - ; - // NOTE: we recognize the attribute by name, not precise type. This makes the generator // more flexible and avoids requiring any sort of run-time dependency. @@ -129,8 +137,8 @@ bool IsExport(AttributeData attr) foreach (var attr in attrs) { - var serviceAttr = IsService(attr) || IsKeyedService(attr) ? attr : null; - if (serviceAttr == null && !IsExport(attr)) + var serviceAttr = IsServiceAttribute(attr) || IsKeyedServiceAttribute(attr) ? attr : null; + if (serviceAttr == null && !IsExportAttribute(attr)) continue; TypedConstant? key = default; @@ -139,7 +147,7 @@ bool IsExport(AttributeData attr) var lifetime = serviceAttr != null ? 0 : 2; if (serviceAttr != null) { - if (IsKeyedService(serviceAttr)) + if (IsKeyedServiceAttribute(serviceAttr)) { key = serviceAttr.ConstructorArguments[0]; lifetime = (int)serviceAttr.ConstructorArguments[1].Value!; @@ -149,7 +157,7 @@ bool IsExport(AttributeData attr) lifetime = (int)serviceAttr.ConstructorArguments[0].Value!; } } - else if (IsExport(attr)) + else if (IsExportAttribute(attr)) { // In NuGet MEF, [Shared] makes exports singleton if (attrs.Any(a => a.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::System.Composition.SharedAttribute")) @@ -216,6 +224,20 @@ serviceAttrSymbol.Symbol is IMethodSymbol attrCtor && .Select((x, _) => x.Left) .Collect(); + var decorations = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => node is InvocationExpressionSyntax invocation && GetInvokedMethodName(invocation) == "Decorate", + transform: static (ctx, cancellation) => GetDecoration((InvocationExpressionSyntax)ctx.Node, ctx.SemanticModel, cancellation)) + .Combine(context.AnalyzerConfigOptionsProvider) + .Where(x => + { + (var decoration, var options) = x; + return options.GlobalOptions.TryGetValue("build_property.AddServicesExtension", out var value) && + bool.TryParse(value, out var addServices) && addServices && decoration is not null; + }) + .Select((x, _) => x.Left!) + .Collect(); + // Project matching service types to register with the given lifetime. var conventionServices = types.Combine(methodInvocations.Combine(context.CompilationProvider)).SelectMany((pair, cancellationToken) => { @@ -245,38 +267,54 @@ serviceAttrSymbol.Symbol is IMethodSymbol attrCtor && .SelectMany((tuple, _) => ImmutableArray.CreateRange([tuple.Item1, tuple.Item2])) .SelectMany((items, _) => items.Distinct().ToImmutableArray()); - RegisterServicesOutput(context, finalServices, context.CompilationProvider); + RegisterServicesOutput(context, finalServices, decorations, context.CompilationProvider); + RegisterDecorateOutput(context, decorations, finalServices.Collect(), context.CompilationProvider); } - void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider services, IncrementalValueProvider compilation) + void RegisterServicesOutput( + IncrementalGeneratorInitializationContext context, + IncrementalValuesProvider services, + IncrementalValueProvider> decorations, + IncrementalValueProvider compilation) { context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 0 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddSingleton", ctx, data)); + services.Where(x => x!.Lifetime == 0 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddSingleton", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 1 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddScoped", ctx, data)); + services.Where(x => x!.Lifetime == 1 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddScoped", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 2 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddTransient", ctx, data)); + services.Where(x => x!.Lifetime == 2 && x.Key is null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, null)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddTransient", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 0 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddKeyedSingleton", ctx, data)); + services.Where(x => x!.Lifetime == 0 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddKeyedSingleton", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 1 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddKeyedScoped", ctx, data)); + services.Where(x => x!.Lifetime == 1 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddKeyedScoped", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 2 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(compilation), - (ctx, data) => AddPartial("AddKeyedTransient", ctx, data)); + services.Where(x => x!.Lifetime == 2 && x.Key is not null).Select((x, _) => new KeyedService(x!.TImplementation, x!.TService, x.Key!)).Collect().Combine(decorations).Combine(compilation), + (ctx, data) => AddPartial("AddKeyedTransient", ctx, (data.Left.Left, data.Left.Right, data.Right))); context.RegisterImplementationSourceOutput(services.Collect(), ReportInconsistencies); } + void RegisterDecorateOutput( + IncrementalGeneratorInitializationContext context, + IncrementalValueProvider> decorations, + IncrementalValueProvider> services, + IncrementalValueProvider compilation) + { + context.RegisterImplementationSourceOutput( + decorations.Combine(services).Combine(compilation), + (ctx, data) => AddDecoratePartial(ctx, data.Left.Left, data.Left.Right, data.Right)); + } + void ReportInconsistencies(SourceProductionContext context, ImmutableArray array) { var grouped = array.GroupBy(x => x.TImplementation, SymbolEqualityComparer.Default).Where(g => g.Count() > 1).ToImmutableArray(); @@ -304,6 +342,253 @@ void ReportInconsistencies(SourceProductionContext context, ImmutableArray decorations, + ImmutableArray services, + Compilation compilation) + { + if (decorations.IsEmpty) + return; + + var validDecorations = ImmutableArray.CreateBuilder<(DecoratedService Decoration, IMethodSymbol Constructor)>(); + + foreach (var decoration in decorations) + { + if (!ValidateDecoration(ctx, decoration, services, compilation, out var constructor)) + continue; + + validDecorations.Add((decoration, constructor!)); + } + + if (validDecorations.Count == 0) + return; + + var builder = new StringBuilder() + .AppendLine("// "); + + foreach (var alias in compilation.References.SelectMany(r => r.Properties.Aliases)) + { + builder.AppendLine($"extern alias {alias};"); + } + + builder.AppendLine( + """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + static partial class AddServicesNoReflectionExtension + { + static partial void DecorateServices(IServiceCollection services) + where TDecorated : class + where TDecorator : class, TDecorated + { + """); + + for (var i = 0; i < validDecorations.Count; i++) + { + var (decoration, _) = validDecorations[i]; + var decorated = decoration.TDecorated.ToFullName(compilation); + var decorator = decoration.TDecorator.ToFullName(compilation); + + builder.AppendLine($" if (typeof(TDecorated) == typeof({decorated}) && typeof(TDecorator) == typeof({decorator}))"); + builder.AppendLine(" {"); + builder.AppendLine($" DecorateDescriptors<{decorated}, {decorator}>(services, CreateDecorator{i});"); + builder.AppendLine(" return;"); + builder.AppendLine(" }"); + } + + builder.AppendLine( + """ + } + """); + + for (var i = 0; i < validDecorations.Count; i++) + { + var (decoration, ctor) = validDecorations[i]; + var decorated = decoration.TDecorated.ToFullName(compilation); + var decorator = decoration.TDecorator.ToFullName(compilation); + var usedDecorated = false; + var args = string.Join(", ", ctor.Parameters.Select(p => + { + if (!usedDecorated && SymbolEqualityComparer.Default.Equals(p.Type, decoration.TDecorated)) + { + usedDecorated = true; + return $"GetDecorated<{decorated}>(s, descriptor)"; + } + + var fromKeyed = p.GetAttributes().FirstOrDefault(IsFromKeyed); + if (fromKeyed is not null) + return $"s.GetRequiredKeyedService<{p.Type.ToFullName(compilation)}>({fromKeyed.ConstructorArguments[0].ToCSharpString()})"; + + return $"s.GetRequiredService<{p.Type.ToFullName(compilation)}>()"; + })); + + builder.AppendLine(); + builder.AppendLine($" static {decorated} CreateDecorator{i}(IServiceProvider s, ServiceDescriptor descriptor)"); + builder.AppendLine($" => new {decorator}({args});"); + } + + builder.AppendLine( + """ + } + } + """); + + ctx.AddSource("Decorate.g", builder.ToString().Replace("\r\n", "\n").Replace("\n", Environment.NewLine)); + } + + bool ValidateDecoration( + SourceProductionContext ctx, + DecoratedService decoration, + ImmutableArray services, + Compilation compilation, + out IMethodSymbol? constructor) + { + constructor = GetDecoratorConstructor(decoration, compilation); + var isValid = true; + var decoratorLifetimes = GetDecoratorLifetimes(decoration, compilation); + + if (decoratorLifetimes.IsEmpty) + { + ctx.ReportDiagnostic(Diagnostic.Create( + DecoratorMustBeService, + decoration.Location, + decoration.TDecorator.ToDisplayString())); + isValid = false; + } + + var decoratedLifetimes = GetDecoratedLifetimes(decoration, services, compilation); + if (!decoratorLifetimes.IsEmpty && !decoratedLifetimes.IsEmpty && + (decoratorLifetimes.Length != 1 || decoratedLifetimes.Any(x => x != decoratorLifetimes[0]))) + { + ctx.ReportDiagnostic(Diagnostic.Create( + DecoratorLifetimeIncompatible, + decoration.Location, + decoration.TDecorator.ToDisplayString(), + string.Join(", ", decoratorLifetimes.Select(LifetimeName)), + decoration.TDecorated.ToDisplayString(), + string.Join(", ", decoratedLifetimes.Select(LifetimeName)))); + isValid = false; + } + + if (constructor is null) + { + ctx.ReportDiagnostic(Diagnostic.Create( + DecoratorConstructorMissing, + decoration.Location, + decoration.TDecorator.ToDisplayString(), + decoration.TDecorated.ToDisplayString())); + isValid = false; + } + + return isValid; + } + + static ImmutableArray GetDecoratorLifetimes(DecoratedService decoration, Compilation compilation) + { + return GetServiceAttributes(decoration.TDecorator) + .Where(x => x.Key is null) + .Where(x => + x.ServiceType is null ? + compilation.HasImplicitConversion(decoration.TDecorator, decoration.TDecorated) : + SymbolEqualityComparer.Default.Equals(x.ServiceType, decoration.TDecorated)) + .Select(x => x.Lifetime) + .Distinct() + .ToImmutableArray(); + } + + static ImmutableArray GetDecoratedLifetimes(DecoratedService decoration, ImmutableArray services, Compilation compilation) + { + return services + .Where(x => x.Key is null) + .Where(x => !SymbolEqualityComparer.Default.Equals(x.TImplementation, decoration.TDecorator)) + .Where(x => + x.TService is null ? + compilation.HasImplicitConversion(x.TImplementation, decoration.TDecorated) : + SymbolEqualityComparer.Default.Equals(x.TService, decoration.TDecorated)) + .Select(x => x.Lifetime) + .Distinct() + .ToImmutableArray(); + } + + IMethodSymbol? GetDecoratorConstructor(DecoratedService decoration, Compilation compilation) + { + var candidates = decoration.TDecorator.InstanceConstructors + .Where(x => compilation.IsSymbolAccessible(x)) + .Where(x => x.Parameters.Count(p => SymbolEqualityComparer.Default.Equals(p.Type, decoration.TDecorated)) == 1) + .ToImmutableArray(); + + if (candidates.IsDefaultOrEmpty) + return null; + + return candidates.FirstOrDefault(HasImportingConstructor) ?? + candidates.OrderByDescending(m => m.Parameters.Length).FirstOrDefault(); + } + + static bool HasImportingConstructor(IMethodSymbol method) => + method.GetAttributes().Any(a => + a.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::System.Composition.ImportingConstructorAttribute" || + a.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::System.ComponentModel.Composition.ImportingConstructorAttribute"); + + static bool IsDecoratorServiceAlias(INamedTypeSymbol implementation, INamedTypeSymbol service, ImmutableArray decorations) => + decorations.Any(x => + SymbolEqualityComparer.Default.Equals(x.TDecorator, implementation) && + SymbolEqualityComparer.Default.Equals(x.TDecorated, service)); + + static ImmutableArray GetServiceAttributes(INamedTypeSymbol type) + { + var builder = ImmutableArray.CreateBuilder(); + + foreach (var attr in type.GetAttributes()) + { + if (!IsServiceAttribute(attr) && !IsKeyedServiceAttribute(attr)) + continue; + + var lifetime = IsKeyedServiceAttribute(attr) ? + (int)attr.ConstructorArguments[1].Value! : + (int)attr.ConstructorArguments[0].Value!; + var key = IsKeyedServiceAttribute(attr) ? attr.ConstructorArguments[0] : (TypedConstant?)null; + var serviceType = attr.AttributeClass?.IsGenericType == true && + attr.AttributeClass.TypeArguments.Length == 1 && + attr.AttributeClass.TypeArguments[0] is INamedTypeSymbol namedService ? + namedService : + null; + + builder.Add(new ServiceAttributeInfo( + lifetime, + key, + serviceType, + attr.ApplicationSyntaxReference?.GetSyntax().GetLocation())); + } + + return builder.ToImmutable(); + } + + static bool IsServiceAttribute(AttributeData attr) => + (attr.AttributeClass?.Name == "ServiceAttribute" || attr.AttributeClass?.Name == "Service") && + attr.ConstructorArguments.Length == 1 && + attr.ConstructorArguments[0].Kind == TypedConstantKind.Enum && + attr.ConstructorArguments[0].Type?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::Microsoft.Extensions.DependencyInjection.ServiceLifetime"; + + static bool IsKeyedServiceAttribute(AttributeData attr) => + (attr.AttributeClass?.Name == "ServiceAttribute" || attr.AttributeClass?.Name == "Service" || + attr.AttributeClass?.Name == "KeyedService" || attr.AttributeClass?.Name == "KeyedServiceAttribute") && + attr.ConstructorArguments.Length == 2 && + attr.ConstructorArguments[1].Kind == TypedConstantKind.Enum && + attr.ConstructorArguments[1].Type?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::Microsoft.Extensions.DependencyInjection.ServiceLifetime"; + + static bool IsExportAttribute(AttributeData attr) + { + var attrName = attr.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + return attrName == "global::System.Composition.ExportAttribute" || + attrName == "global::System.ComponentModel.Composition.ExportAttribute"; + } + + static string LifetimeName(int lifetime) => + lifetime switch { 0 => "Singleton", 1 => "Scoped", 2 => "Transient", _ => "Unknown" }; + static string? GetInvokedMethodName(InvocationExpressionSyntax invocation) => invocation.Expression switch { MemberAccessExpressionSyntax memberAccess => memberAccess.Name.Identifier.Text, @@ -362,7 +647,28 @@ void ReportInconsistencies(SourceProductionContext context, ImmutableArray Types, Compilation Compilation) data) + static DecoratedService? GetDecoration(InvocationExpressionSyntax invocation, SemanticModel semanticModel, CancellationToken cancellation) + { + var symbolInfo = semanticModel.GetSymbolInfo(invocation, cancellation); + if (symbolInfo.Symbol is not IMethodSymbol methodSymbol) + return null; + + if (!HasMethodAttribute(methodSymbol, "DDIDecorateAttribute")) + return null; + + if (methodSymbol.TypeArguments.Length != 2 || + methodSymbol.TypeArguments[0] is not INamedTypeSymbol decorated || + methodSymbol.TypeArguments[1] is not INamedTypeSymbol decorator) + return null; + + return new DecoratedService(decorated, decorator, invocation.GetLocation()); + } + + static bool HasMethodAttribute(IMethodSymbol method, string attributeName) => + method.GetAttributes().Any(attr => attr.AttributeClass?.Name == attributeName) || + method.ReducedFrom?.GetAttributes().Any(attr => attr.AttributeClass?.Name == attributeName) == true; + + void AddPartial(string methodName, SourceProductionContext ctx, (ImmutableArray Types, ImmutableArray Decorations, Compilation Compilation) data) { if (data.Types.IsEmpty) return; @@ -388,8 +694,8 @@ static partial class AddServicesNoReflectionExtension { """); - AddServices(data.Types.Where(x => x.Key is null), data.Compilation, methodName, builder); - AddKeyedServices(data.Types.Where(x => x.Key is not null), data.Compilation, methodName, builder); + AddServices(data.Types.Where(x => x.Key is null), data.Decorations, data.Compilation, methodName, builder); + AddKeyedServices(data.Types.Where(x => x.Key is not null), data.Decorations, data.Compilation, methodName, builder); builder.AppendLine( """ @@ -401,7 +707,7 @@ static partial class AddServicesNoReflectionExtension ctx.AddSource(methodName + ".g", builder.ToString().Replace("\r\n", "\n").Replace("\n", Environment.NewLine)); } - void AddServices(IEnumerable services, Compilation compilation, string methodName, StringBuilder output) + void AddServices(IEnumerable services, ImmutableArray decorations, Compilation compilation, string methodName, StringBuilder output) { bool isAccessible(ISymbol s) => compilation.IsSymbolAccessible(s); @@ -445,6 +751,9 @@ void AddServices(IEnumerable services, Compilation compilation, st foreach (var iface in serviceTypes) { + if (IsDecoratorServiceAlias(type, iface, decorations)) + continue; + if (!compilation.HasImplicitConversion(type, iface)) continue; @@ -492,7 +801,7 @@ void AddServices(IEnumerable services, Compilation compilation, st } } - void AddKeyedServices(IEnumerable services, Compilation compilation, string methodName, StringBuilder output) + void AddKeyedServices(IEnumerable services, ImmutableArray decorations, Compilation compilation, string methodName, StringBuilder output) { bool isAccessible(ISymbol s) => compilation.IsSymbolAccessible(s); @@ -537,6 +846,9 @@ void AddKeyedServices(IEnumerable services, Compilation compilatio foreach (var iface in serviceTypes) { + if (IsDecoratorServiceAlias(type, iface, decorations)) + continue; + var ifaceName = iface.ToFullName(compilation); if (!registered.Contains(ifaceName)) { diff --git a/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs b/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs index d3141cb..048371d 100644 --- a/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs +++ b/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs @@ -1,6 +1,7 @@ // using System; using System.ComponentModel; +using System.Linq; namespace Microsoft.Extensions.DependencyInjection { @@ -89,6 +90,26 @@ public static IServiceCollection AddServices(this IServiceCollection services) return services; } + /// + /// Decorates registrations for with . + /// + /// + /// The decorated service must already be registered before this method is invoked. The generated + /// implementation replaces matching registrations in-place and preserves each registration lifetime. + /// + /// The service type to decorate. + /// The decorator type. It must implement . + /// The to update. + /// The so that additional calls can be chained. + [DDIDecorate] + public static IServiceCollection Decorate(this IServiceCollection services) + where TDecorated : class + where TDecorator : class, TDecorated + { + DecorateServices(services); + return services; + } + /// /// Adds discovered scoped services to the collection. /// @@ -119,7 +140,85 @@ public static IServiceCollection AddServices(this IServiceCollection services) /// static partial void AddKeyedTransientServices(IServiceCollection services); + /// + /// Applies source-generated service decorations. + /// + static partial void DecorateServices(IServiceCollection services) + where TDecorated : class + where TDecorator : class, TDecorated; + + static void DecorateDescriptors( + IServiceCollection services, + Func factory) + where TDecorated : class + where TDecorator : class, TDecorated + { + var descriptors = services + .Select((descriptor, index) => new { descriptor, index }) + .Where(x => + !x.descriptor.IsKeyedService && + x.descriptor.ServiceType == typeof(TDecorated) && + GetImplementationType(x.descriptor) != typeof(TDecorator)) + .ToArray(); + + if (descriptors.Length == 0) + { + throw new InvalidOperationException( + $"No service registration for {typeof(TDecorated)} was found. Call AddServices before Decorate, or register the decorated service before decorating it."); + } + + foreach (var item in descriptors) + { + services[item.index] = ServiceDescriptor.Describe( + typeof(TDecorated), + provider => factory(provider, item.descriptor), + item.descriptor.Lifetime); + } + } + + static TDecorated GetDecorated(IServiceProvider provider, ServiceDescriptor descriptor) + where TDecorated : class + { + object? service; + + if (descriptor.IsKeyedService) + { + throw new InvalidOperationException($"Keyed service registrations for {typeof(TDecorated)} cannot be decorated by this overload."); + } + + if (descriptor.ImplementationInstance != null) + { + service = descriptor.ImplementationInstance; + } + else if (descriptor.ImplementationFactory != null) + { + service = descriptor.ImplementationFactory(provider); + } + else if (descriptor.ImplementationType != null) + { + service = ActivatorUtilities.GetServiceOrCreateInstance(provider, descriptor.ImplementationType); + } + else + { + throw new InvalidOperationException($"Unsupported service registration for {typeof(TDecorated)}."); + } + + return service as TDecorated ?? + throw new InvalidOperationException($"The decorated registration did not produce an instance of {typeof(TDecorated)}."); + } + + static Type? GetImplementationType(ServiceDescriptor descriptor) + { + if (descriptor.IsKeyedService) + return descriptor.KeyedImplementationType; + + return descriptor.ImplementationType ?? descriptor.ImplementationInstance?.GetType(); + } + [AttributeUsage(AttributeTargets.Method)] class DDIAddServicesAttribute : Attribute { } + + [AttributeUsage(AttributeTargets.Method)] + class DDIDecorateAttribute : Attribute { } } } \ No newline at end of file diff --git a/src/DependencyInjection.Tests/GenerationTests.cs b/src/DependencyInjection.Tests/GenerationTests.cs index 9cb9379..991801d 100644 --- a/src/DependencyInjection.Tests/GenerationTests.cs +++ b/src/DependencyInjection.Tests/GenerationTests.cs @@ -283,6 +283,53 @@ public void RegisterWithSpecificServiceType() Assert.Null(services.GetService()); } + [Fact] + public void DecorateService() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate(); + var services = collection.BuildServiceProvider(); + + using var scope = services.CreateScope(); + + var instance = Assert.IsType(scope.ServiceProvider.GetRequiredService()); + + Assert.IsType(instance.Inner); + Assert.Same(instance, scope.ServiceProvider.GetRequiredService()); + Assert.Same(instance, scope.ServiceProvider.GetRequiredService>().Invoke()); + Assert.Same(instance, scope.ServiceProvider.GetRequiredService>().Value); + Assert.Same(services.GetRequiredService(), instance.Singleton); + } + + [Fact] + public void DecorateMultipleRegistrations() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate(); + var services = collection.BuildServiceProvider(); + + var instances = services.GetServices() + .Cast() + .ToList(); + + Assert.Equal(2, instances.Count); + Assert.Contains(instances, x => x.Inner is FirstMultipleDecoratedService); + Assert.Contains(instances, x => x.Inner is SecondMultipleDecoratedService); + } + + [Fact] + public void DecorateThrowsIfDecoratedServiceIsNotRegistered() + { + var collection = new ServiceCollection(); + + var ex = Assert.Throws(() => + collection.Decorate()); + + Assert.Contains(nameof(IDecoratedService), ex.Message); + } + [GenerationTests.Service(ServiceLifetime.Singleton)] public class MyAttributedService : IAsyncDisposable { @@ -406,4 +453,30 @@ public interface INonSpecificService; public class SpecificServiceType : ISpecificService, INonSpecificService { public void Dispose() => throw new NotImplementedException(); +} + +public interface IDecoratedService { } + +[Service(ServiceLifetime.Scoped)] +public class DecoratedService : IDecoratedService { } + +[Service(ServiceLifetime.Scoped)] +public class DecoratedServiceDecorator(IDecoratedService inner, SingletonService singleton) : IDecoratedService +{ + public IDecoratedService Inner => inner; + public SingletonService Singleton => singleton; +} + +public interface IMultipleDecoratedService { } + +[Service(ServiceLifetime.Transient)] +public class FirstMultipleDecoratedService : IMultipleDecoratedService { } + +[Service(ServiceLifetime.Transient)] +public class SecondMultipleDecoratedService : IMultipleDecoratedService { } + +[Service(ServiceLifetime.Transient)] +public class MultipleDecoratedServiceDecorator(IMultipleDecoratedService inner) : IMultipleDecoratedService +{ + public IMultipleDecoratedService Inner => inner; } \ No newline at end of file From 1212abfa1b0f838f6c933afe7cfc009a00de6c0f Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Sun, 28 Jun 2026 20:52:15 +0000 Subject: [PATCH 2/9] Fix decorate generator test build Co-authored-by: Daniel Cazzulino --- src/CodeAnalysis.Tests/DecorateGeneratorTests.cs | 6 +++--- .../IncrementalGenerator.cs | 2 +- .../compile/AddServicesNoReflectionExtension.cs | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs index f1cca08..040ba14 100644 --- a/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs +++ b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs @@ -32,7 +32,7 @@ public static void Main() { var services = new ServiceCollection(); services.AddServices(); - services.{|#0:Decorate()|}; + {|#0:services.Decorate()|}; } } """); @@ -66,7 +66,7 @@ public static void Main() { var services = new ServiceCollection(); services.AddServices(); - services.{|#0:Decorate()|}; + {|#0:services.Decorate()|}; } } """); @@ -100,7 +100,7 @@ public static void Main() { var services = new ServiceCollection(); services.AddServices(); - services.{|#0:Decorate()|}; + {|#0:services.Decorate()|}; } } """); diff --git a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs index 24f2596..02d71f6 100644 --- a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs +++ b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs @@ -235,7 +235,7 @@ serviceAttrSymbol.Symbol is IMethodSymbol attrCtor && return options.GlobalOptions.TryGetValue("build_property.AddServicesExtension", out var value) && bool.TryParse(value, out var addServices) && addServices && decoration is not null; }) - .Select((x, _) => x.Left!) + .Select((x, _) => x.Left!.Value) .Collect(); // Project matching service types to register with the given lifetime. diff --git a/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs b/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs index 048371d..e14cd13 100644 --- a/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs +++ b/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs @@ -1,4 +1,5 @@ // +#nullable enable using System; using System.ComponentModel; using System.Linq; From a4cc8764ac414e240df4148e5037f572437cae8d Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Sun, 28 Jun 2026 20:56:09 +0000 Subject: [PATCH 3/9] Clarify decorator constructor dependencies Co-authored-by: Daniel Cazzulino --- .../DecorateGeneratorTests.cs | 33 +++++++++++++++++++ .../IncrementalGenerator.cs | 2 +- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs index 040ba14..948ed67 100644 --- a/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs +++ b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs @@ -113,6 +113,39 @@ public static void Main() await test.RunAsync(); } + [Fact] + public async Task NoErrorIfDecoratorConstructorHasOtherDependencies() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + public interface IOtherDependency { } + + [Service] + public class Foo : IFoo { } + + [Service] + public class OtherDependency : IOtherDependency { } + + [Service] + public class FooDecorator(IFoo inner, IOtherDependency other) : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + services.Decorate(); + } + } + """); + + await test.RunAsync(); + } + static CSharpSourceGeneratorTest CreateTest(string source) { return new CSharpSourceGeneratorTest diff --git a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs index 02d71f6..f64c06c 100644 --- a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs +++ b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs @@ -53,7 +53,7 @@ public class IncrementalGenerator : IIncrementalGenerator new DiagnosticDescriptor( "DDI008", "Decorator constructor must accept the decorated service.", - "Decorator type {0} must have an accessible constructor with exactly one parameter of type {1}.", + "Decorator type {0} must have an accessible constructor with exactly one parameter of type {1}; additional dependencies are allowed.", "Build", DiagnosticSeverity.Error, isEnabledByDefault: true); From ae1727bd3559d3134c0bcfe5660a6b8596fe25cd Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Sun, 28 Jun 2026 22:23:33 +0000 Subject: [PATCH 4/9] Add keyed Decorate overload Co-authored-by: Daniel Cazzulino --- .../DecorateGeneratorTests.cs | 66 +++++++++++++ .../IncrementalGenerator.cs | 95 +++++++++++++++++-- .../AddServicesNoReflectionExtension.cs | 91 ++++++++++++++++++ .../GenerationTests.cs | 47 +++++++++ 4 files changed, 291 insertions(+), 8 deletions(-) diff --git a/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs index 948ed67..568ef25 100644 --- a/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs +++ b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs @@ -146,6 +146,72 @@ public static void Main() await test.RunAsync(); } + [Fact] + public async Task NoErrorIfKeyedDecoratorLifetimeMatchesSelectedKey() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + + [Service("foo", ServiceLifetime.Scoped)] + public class Foo : IFoo { } + + [Service("bar", ServiceLifetime.Singleton)] + public class OtherFoo : IFoo { } + + [Service("foo", ServiceLifetime.Scoped)] + public class FooDecorator(IFoo inner) : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + services.Decorate("foo"); + } + } + """); + + await test.RunAsync(); + } + + [Fact] + public async Task ErrorIfKeyedDecoratorLifetimeIsIncompatible() + { + var test = CreateTest( + """ + using Microsoft.Extensions.DependencyInjection; + + public interface IFoo { } + + [Service("foo", ServiceLifetime.Scoped)] + public class Foo : IFoo { } + + [Service("foo", ServiceLifetime.Singleton)] + public class FooDecorator(IFoo inner) : IFoo { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(); + {|#0:services.Decorate("foo")|}; + } + } + """); + + test.ExpectedDiagnostics.Add( + Verifier.Diagnostic(IncrementalGenerator.DecoratorLifetimeIncompatible) + .WithLocation(0) + .WithArguments("FooDecorator", "Singleton", "IFoo", "Scoped")); + + await test.RunAsync(); + } + static CSharpSourceGeneratorTest CreateTest(string source) { return new CSharpSourceGeneratorTest diff --git a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs index f64c06c..3eb2182 100644 --- a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs +++ b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs @@ -11,7 +11,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.Extensions.DependencyInjection; using KeyedService = (Microsoft.CodeAnalysis.INamedTypeSymbol TImplementation, Microsoft.CodeAnalysis.INamedTypeSymbol? TService, Microsoft.CodeAnalysis.TypedConstant? Key); -using DecoratedService = (Microsoft.CodeAnalysis.INamedTypeSymbol TDecorated, Microsoft.CodeAnalysis.INamedTypeSymbol TDecorator, Microsoft.CodeAnalysis.Location? Location); +using DecoratedService = (Microsoft.CodeAnalysis.INamedTypeSymbol TDecorated, Microsoft.CodeAnalysis.INamedTypeSymbol TDecorator, bool IsKeyed, bool HasKeyValue, object? KeyValue, Microsoft.CodeAnalysis.Location? Location); namespace Devlooped.Extensions.DependencyInjection; @@ -365,7 +365,8 @@ void AddDecoratePartial( return; var builder = new StringBuilder() - .AppendLine("// "); + .AppendLine("// ") + .AppendLine("#nullable enable"); foreach (var alias in compilation.References.SelectMany(r => r.Properties.Aliases)) { @@ -389,6 +390,9 @@ static partial void DecorateServices(IServiceCollection for (var i = 0; i < validDecorations.Count; i++) { var (decoration, _) = validDecorations[i]; + if (decoration.IsKeyed) + continue; + var decorated = decoration.TDecorated.ToFullName(compilation); var decorator = decoration.TDecorator.ToFullName(compilation); @@ -399,6 +403,32 @@ static partial void DecorateServices(IServiceCollection builder.AppendLine(" }"); } + builder.AppendLine( + """ + } + + static partial void DecorateKeyedServices(IServiceCollection services, object? key) + where TDecorated : class + where TDecorator : class, TDecorated + { + """); + + for (var i = 0; i < validDecorations.Count; i++) + { + var (decoration, _) = validDecorations[i]; + if (!decoration.IsKeyed) + continue; + + var decorated = decoration.TDecorated.ToFullName(compilation); + var decorator = decoration.TDecorator.ToFullName(compilation); + + builder.AppendLine($" if (typeof(TDecorated) == typeof({decorated}) && typeof(TDecorator) == typeof({decorator}))"); + builder.AppendLine(" {"); + builder.AppendLine($" DecorateKeyedDescriptors<{decorated}, {decorator}>(services, key, CreateKeyedDecorator{i});"); + builder.AppendLine(" return;"); + builder.AppendLine(" }"); + } + builder.AppendLine( """ } @@ -407,6 +437,9 @@ static partial void DecorateServices(IServiceCollection for (var i = 0; i < validDecorations.Count; i++) { var (decoration, ctor) = validDecorations[i]; + if (decoration.IsKeyed) + continue; + var decorated = decoration.TDecorated.ToFullName(compilation); var decorator = decoration.TDecorator.ToFullName(compilation); var usedDecorated = false; @@ -430,6 +463,35 @@ static partial void DecorateServices(IServiceCollection builder.AppendLine($" => new {decorator}({args});"); } + for (var i = 0; i < validDecorations.Count; i++) + { + var (decoration, ctor) = validDecorations[i]; + if (!decoration.IsKeyed) + continue; + + var decorated = decoration.TDecorated.ToFullName(compilation); + var decorator = decoration.TDecorator.ToFullName(compilation); + var usedDecorated = false; + var args = string.Join(", ", ctor.Parameters.Select(p => + { + if (!usedDecorated && SymbolEqualityComparer.Default.Equals(p.Type, decoration.TDecorated)) + { + usedDecorated = true; + return $"GetKeyedDecorated<{decorated}>(s, key, descriptor)"; + } + + var fromKeyed = p.GetAttributes().FirstOrDefault(IsFromKeyed); + if (fromKeyed is not null) + return $"s.GetRequiredKeyedService<{p.Type.ToFullName(compilation)}>({fromKeyed.ConstructorArguments[0].ToCSharpString()})"; + + return $"s.GetRequiredService<{p.Type.ToFullName(compilation)}>()"; + })); + + builder.AppendLine(); + builder.AppendLine($" static {decorated} CreateKeyedDecorator{i}(IServiceProvider s, object? key, ServiceDescriptor descriptor)"); + builder.AppendLine($" => new {decorator}({args});"); + } + builder.AppendLine( """ } @@ -489,7 +551,7 @@ bool ValidateDecoration( static ImmutableArray GetDecoratorLifetimes(DecoratedService decoration, Compilation compilation) { return GetServiceAttributes(decoration.TDecorator) - .Where(x => x.Key is null) + .Where(x => IsMatchingDecoratorAttribute(decoration, x)) .Where(x => x.ServiceType is null ? compilation.HasImplicitConversion(decoration.TDecorator, decoration.TDecorated) : @@ -502,7 +564,8 @@ x.ServiceType is null ? static ImmutableArray GetDecoratedLifetimes(DecoratedService decoration, ImmutableArray services, Compilation compilation) { return services - .Where(x => x.Key is null) + .Where(x => decoration.IsKeyed ? x.Key is not null : x.Key is null) + .Where(x => !decoration.IsKeyed || !decoration.HasKeyValue || Equals(x.Key?.Value, decoration.KeyValue)) .Where(x => !SymbolEqualityComparer.Default.Equals(x.TImplementation, decoration.TDecorator)) .Where(x => x.TService is null ? @@ -513,6 +576,9 @@ x.TService is null ? .ToImmutableArray(); } + static bool IsMatchingDecoratorAttribute(DecoratedService decoration, ServiceAttributeInfo attribute) => + decoration.IsKeyed || attribute.Key is null; + IMethodSymbol? GetDecoratorConstructor(DecoratedService decoration, Compilation compilation) { var candidates = decoration.TDecorator.InstanceConstructors @@ -532,8 +598,9 @@ static bool HasImportingConstructor(IMethodSymbol method) => a.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::System.Composition.ImportingConstructorAttribute" || a.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == "global::System.ComponentModel.Composition.ImportingConstructorAttribute"); - static bool IsDecoratorServiceAlias(INamedTypeSymbol implementation, INamedTypeSymbol service, ImmutableArray decorations) => + static bool IsDecoratorServiceAlias(INamedTypeSymbol implementation, INamedTypeSymbol service, bool keyed, ImmutableArray decorations) => decorations.Any(x => + x.IsKeyed == keyed && SymbolEqualityComparer.Default.Equals(x.TDecorator, implementation) && SymbolEqualityComparer.Default.Equals(x.TDecorated, service)); @@ -661,7 +728,19 @@ methodSymbol.TypeArguments[0] is not INamedTypeSymbol decorated || methodSymbol.TypeArguments[1] is not INamedTypeSymbol decorator) return null; - return new DecoratedService(decorated, decorator, invocation.GetLocation()); + var isKeyed = methodSymbol.Parameters.Any(p => p.Name == "key") || + methodSymbol.ReducedFrom?.Parameters.Any(p => p.Name == "key") == true; + var hasKeyValue = false; + object? keyValue = null; + + if (isKeyed && invocation.ArgumentList.Arguments.Count > 0) + { + var key = semanticModel.GetConstantValue(invocation.ArgumentList.Arguments[0].Expression, cancellation); + hasKeyValue = key.HasValue; + keyValue = key.Value; + } + + return new DecoratedService(decorated, decorator, isKeyed, hasKeyValue, keyValue, invocation.GetLocation()); } static bool HasMethodAttribute(IMethodSymbol method, string attributeName) => @@ -751,7 +830,7 @@ void AddServices(IEnumerable services, ImmutableArray services, ImmutableArray(this IServiceC return services; } + /// + /// Decorates keyed registrations for with . + /// + /// + /// The decorated service must already be registered with before this method is invoked. + /// The generated implementation replaces matching registrations in-place and preserves each registration lifetime. + /// + /// The service type to decorate. + /// The decorator type. It must implement . + /// The to update. + /// The key for the service registration to decorate. + /// The so that additional calls can be chained. + [DDIDecorate] + public static IServiceCollection Decorate(this IServiceCollection services, object? key) + where TDecorated : class + where TDecorator : class, TDecorated + { + DecorateKeyedServices(services, key); + return services; + } + /// /// Adds discovered scoped services to the collection. /// @@ -148,6 +169,13 @@ static partial void DecorateServices(IServiceCollection where TDecorated : class where TDecorator : class, TDecorated; + /// + /// Applies source-generated keyed service decorations. + /// + static partial void DecorateKeyedServices(IServiceCollection services, object? key) + where TDecorated : class + where TDecorator : class, TDecorated; + static void DecorateDescriptors( IServiceCollection services, Func factory) @@ -177,6 +205,38 @@ static void DecorateDescriptors( } } + static void DecorateKeyedDescriptors( + IServiceCollection services, + object? key, + Func factory) + where TDecorated : class + where TDecorator : class, TDecorated + { + var descriptors = services + .Select((descriptor, index) => new { descriptor, index }) + .Where(x => + x.descriptor.IsKeyedService && + x.descriptor.ServiceType == typeof(TDecorated) && + Equals(x.descriptor.ServiceKey, key) && + GetImplementationType(x.descriptor) != typeof(TDecorator)) + .ToArray(); + + if (descriptors.Length == 0) + { + throw new InvalidOperationException( + $"No keyed service registration for {typeof(TDecorated)} with key '{key}' was found. Call AddServices before Decorate, or register the decorated service before decorating it."); + } + + foreach (var item in descriptors) + { + services[item.index] = ServiceDescriptor.DescribeKeyed( + typeof(TDecorated), + key, + (provider, serviceKey) => factory(provider, serviceKey, item.descriptor), + item.descriptor.Lifetime); + } + } + static TDecorated GetDecorated(IServiceProvider provider, ServiceDescriptor descriptor) where TDecorated : class { @@ -208,6 +268,37 @@ static TDecorated GetDecorated(IServiceProvider provider, ServiceDes throw new InvalidOperationException($"The decorated registration did not produce an instance of {typeof(TDecorated)}."); } + static TDecorated GetKeyedDecorated(IServiceProvider provider, object? key, ServiceDescriptor descriptor) + where TDecorated : class + { + object? service; + + if (!descriptor.IsKeyedService) + { + throw new InvalidOperationException($"Non-keyed service registrations for {typeof(TDecorated)} cannot be decorated by this overload."); + } + + if (descriptor.KeyedImplementationInstance != null) + { + service = descriptor.KeyedImplementationInstance; + } + else if (descriptor.KeyedImplementationFactory != null) + { + service = descriptor.KeyedImplementationFactory(provider, key); + } + else if (descriptor.KeyedImplementationType != null) + { + service = ActivatorUtilities.GetServiceOrCreateInstance(provider, descriptor.KeyedImplementationType); + } + else + { + throw new InvalidOperationException($"Unsupported keyed service registration for {typeof(TDecorated)}."); + } + + return service as TDecorated ?? + throw new InvalidOperationException($"The decorated keyed registration did not produce an instance of {typeof(TDecorated)}."); + } + static Type? GetImplementationType(ServiceDescriptor descriptor) { if (descriptor.IsKeyedService) diff --git a/src/DependencyInjection.Tests/GenerationTests.cs b/src/DependencyInjection.Tests/GenerationTests.cs index 991801d..ad1fdef 100644 --- a/src/DependencyInjection.Tests/GenerationTests.cs +++ b/src/DependencyInjection.Tests/GenerationTests.cs @@ -319,6 +319,25 @@ public void DecorateMultipleRegistrations() Assert.Contains(instances, x => x.Inner is SecondMultipleDecoratedService); } + [Fact] + public void DecorateKeyedService() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate("decorated"); + var services = collection.BuildServiceProvider(); + + var instance = Assert.IsType( + services.GetRequiredKeyedService("decorated")); + + Assert.IsType(instance.Inner); + Assert.Same(services.GetRequiredService(), instance.Singleton); + + var factory = services.GetRequiredKeyedService>("decorated"); + Assert.IsType(factory()); + Assert.IsType(services.GetRequiredKeyedService("other")); + } + [Fact] public void DecorateThrowsIfDecoratedServiceIsNotRegistered() { @@ -330,6 +349,19 @@ public void DecorateThrowsIfDecoratedServiceIsNotRegistered() Assert.Contains(nameof(IDecoratedService), ex.Message); } + [Fact] + public void DecorateKeyedThrowsIfDecoratedServiceIsNotRegistered() + { + var collection = new ServiceCollection(); + collection.AddServices(); + + var ex = Assert.Throws(() => + collection.Decorate("missing")); + + Assert.Contains(nameof(IKeyedDecoratedService), ex.Message); + Assert.Contains("missing", ex.Message); + } + [GenerationTests.Service(ServiceLifetime.Singleton)] public class MyAttributedService : IAsyncDisposable { @@ -479,4 +511,19 @@ public class SecondMultipleDecoratedService : IMultipleDecoratedService { } public class MultipleDecoratedServiceDecorator(IMultipleDecoratedService inner) : IMultipleDecoratedService { public IMultipleDecoratedService Inner => inner; +} + +public interface IKeyedDecoratedService { } + +[Service("decorated", ServiceLifetime.Singleton)] +public class KeyedDecoratedService : IKeyedDecoratedService { } + +[Service("other", ServiceLifetime.Singleton)] +public class OtherKeyedDecoratedService : IKeyedDecoratedService { } + +[Service("decorated", ServiceLifetime.Singleton)] +public class KeyedDecoratedServiceDecorator(IKeyedDecoratedService inner, SingletonService singleton) : IKeyedDecoratedService +{ + public IKeyedDecoratedService Inner => inner; + public SingletonService Singleton => singleton; } \ No newline at end of file From e36dd1c19aae9d6ed249c52f3cc73366d5dc64e6 Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Sun, 28 Jun 2026 19:44:20 -0300 Subject: [PATCH 5/9] Apply dotnet format --- src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs index 3eb2182..3d63715 100644 --- a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs +++ b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs @@ -10,8 +10,8 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.Extensions.DependencyInjection; -using KeyedService = (Microsoft.CodeAnalysis.INamedTypeSymbol TImplementation, Microsoft.CodeAnalysis.INamedTypeSymbol? TService, Microsoft.CodeAnalysis.TypedConstant? Key); using DecoratedService = (Microsoft.CodeAnalysis.INamedTypeSymbol TDecorated, Microsoft.CodeAnalysis.INamedTypeSymbol TDecorator, bool IsKeyed, bool HasKeyValue, object? KeyValue, Microsoft.CodeAnalysis.Location? Location); +using KeyedService = (Microsoft.CodeAnalysis.INamedTypeSymbol TImplementation, Microsoft.CodeAnalysis.INamedTypeSymbol? TService, Microsoft.CodeAnalysis.TypedConstant? Key); namespace Devlooped.Extensions.DependencyInjection; From 1b8e2bdbc6046afeb502a4d8a022bfcf19a4c68b Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Sun, 28 Jun 2026 19:52:57 -0300 Subject: [PATCH 6/9] Align Roslyn testing packages to 1.1.4 to fix NU1608 warnings. Co-authored-by: Cursor --- src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj b/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj index b623a81..b3595bf 100644 --- a/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj +++ b/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj @@ -12,9 +12,9 @@ - - - + + + From 3afadfe9bb812515c36d8febdefb728c1afc40cd Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Sun, 28 Jun 2026 19:58:42 -0300 Subject: [PATCH 7/9] Add Decorate sample and document decorating services Showcase the Decorate feature in the sample app with a stopwatch timer decorator on the echo command handler, and document usage in the readme including keyed decoration. --- readme.md | 55 +++++++++++++++++++++++++++++++ src/Samples/ConsoleApp/Program.cs | 19 +++++++++++ 2 files changed, 74 insertions(+) diff --git a/readme.md b/readme.md index 35dcdbf..d1a997f 100644 --- a/readme.md +++ b/readme.md @@ -174,6 +174,61 @@ Note you can also register the same service using multiple keys, as shown in the > [!IMPORTANT] > Keyed services are a feature of version 8.0+ of Microsoft.Extensions.DependencyInjection +### Decorating Services + +After services are registered with `AddServices`, you can wrap existing registrations +with a decorator using the `Decorate()` extension method. +The decorated service must already be registered, and the source generator replaces +matching registrations in-place while preserving each registration's lifetime. + +The decorator type must implement `TDecorated`, be annotated with `[Service]`, and +provide a constructor that accepts the decorated service as one of its parameters +(additional dependencies are resolved from the container as usual): + +```csharp +public interface INotificationService +{ + void Send(string message); +} + +[Service(ServiceLifetime.Scoped)] +public class EmailNotificationService : INotificationService +{ + public void Send(string message) => Console.WriteLine($"[Email] {message}"); +} + +[Service(ServiceLifetime.Scoped)] +public class LoggingNotificationService(INotificationService inner) : INotificationService +{ + public void Send(string message) + { + Console.WriteLine("Sending notification..."); + inner.Send(message); + } +} + +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddServices(); +builder.Services.Decorate(); +``` + +When resolving `INotificationService`, the container returns a `LoggingNotificationService` +that wraps the original `EmailNotificationService`. `Func` and `Lazy` registrations +for the same service type also resolve the decorated instance. + +If multiple registrations exist for the same service type, all of them are decorated. +For keyed services, pass the key to decorate a specific registration: + +```csharp +builder.Services.AddServices(); +builder.Services.Decorate("email"); +``` + +The generator validates decorations at compile-time: the decorator must be a registered +service, its lifetime must be compatible with the decorated registration, and its +constructor must accept the decorated service type. + ## How It Works In all cases, the generated code that implements the registration looks like the following: diff --git a/src/Samples/ConsoleApp/Program.cs b/src/Samples/ConsoleApp/Program.cs index c78b9fa..7003e34 100644 --- a/src/Samples/ConsoleApp/Program.cs +++ b/src/Samples/ConsoleApp/Program.cs @@ -1,6 +1,7 @@ extern alias Library1; extern alias Library2; +using System.Diagnostics; using Merq; using Microsoft.Extensions.DependencyInjection; @@ -10,9 +11,27 @@ // Library1 contains [Service]-annotated classes, which will be automatically registered here. collection.AddServices(); +// Wrap the echo handler with a stopwatch timer that logs execution time. +collection.Decorate, EchoHandlerTimer>(); + var services = collection.BuildServiceProvider(); var handler = services.GetRequiredService>(); var message = handler.Execute(new Library1::Library.Echo("Hello")); Console.WriteLine(message); + +[Service] +class EchoHandlerTimer(ICommandHandler inner) : ICommandHandler +{ + public bool CanExecute(Library1::Library.Echo command) => inner.CanExecute(command); + + public string Execute(Library1::Library.Echo command) + { + var stopwatch = Stopwatch.StartNew(); + var result = inner.Execute(command); + stopwatch.Stop(); + Console.WriteLine($"Echo executed in {stopwatch.Elapsed.TotalMilliseconds:F3} ms"); + return result; + } +} \ No newline at end of file From 82d175287d4d67c29544d368d38c9b2c77ec68b9 Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Mon, 29 Jun 2026 06:05:32 -0300 Subject: [PATCH 8/9] feat(decorate): make [Service] optional on decorators + deduplicate keyed paths - [Service] is no longer required on TDecorator for Decorate. Lifetime compatibility validation is now only performed when the decorator actually declares [Service] (or keyed variant) attributes. - Reduced duplication between keyed and non-keyed code: - Runtime helpers: unified DecorateDescriptors (now handles key=null for non-keyed) and GetDecoratedCore (single implementation of the Instance/Factory/ImplementationType materialization logic). - Generator: extracted BuildDecoratorArgs helper to eliminate duplicated constructor argument emission for CreateDecoratorN / CreateKeyedDecoratorN. - Cleanups: - Removed DecoratorMustBeService (DDI006) diagnostic and the associated generator test (ErrorIfDecoratorIsNotService). - Removed legacy DecorateKeyedDescriptors forwarding method; generator now emits calls to the unified DecorateDescriptors overload. - Updated readme documentation to reflect that [Service] on the decorator is optional. - Added NoErrorIfDecoratorHasNoServiceAttribute (generator) and DecorateWorksWithDecoratorThatHasNoServiceAttribute (end-to-end) tests. - Added the README 'Decorating Services' example as DecorateServiceReadmeExample in GenerationTests. - Existing behavior preserved when [Service] is present on the decorator (lifetime checks, suppression via IsDecoratorServiceAlias, etc.). --- readme.md | 14 +- .../DecorateGeneratorTests.cs | 47 +++--- .../IncrementalGenerator.cs | 70 ++++----- .../AddServicesNoReflectionExtension.cs | 146 +++++++++--------- .../GenerationTests.cs | 68 ++++++++ 5 files changed, 199 insertions(+), 146 deletions(-) diff --git a/readme.md b/readme.md index d1a997f..55be9eb 100644 --- a/readme.md +++ b/readme.md @@ -181,9 +181,10 @@ with a decorator using the `Decorate()` extension method The decorated service must already be registered, and the source generator replaces matching registrations in-place while preserving each registration's lifetime. -The decorator type must implement `TDecorated`, be annotated with `[Service]`, and -provide a constructor that accepts the decorated service as one of its parameters -(additional dependencies are resolved from the container as usual): +The decorator type must implement `TDecorated` and provide a constructor that accepts +the decorated service as one of its parameters (additional dependencies are resolved +from the container as usual). Annotating the decorator with `[Service]` is optional +(when present, lifetime compatibility with the decorated service is validated at compile time): ```csharp public interface INotificationService @@ -225,9 +226,10 @@ builder.Services.AddServices(); builder.Services.Decorate("email"); ``` -The generator validates decorations at compile-time: the decorator must be a registered -service, its lifetime must be compatible with the decorated registration, and its -constructor must accept the decorated service type. +The generator validates decorations at compile-time: the decorator must have a constructor +that accepts the decorated service type (plus any additional dependencies). If the decorator +is annotated with `[Service]`, its lifetime is also validated for compatibility with the +decorated registration(s). ## How It Works diff --git a/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs index 568ef25..4a10477 100644 --- a/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs +++ b/src/CodeAnalysis.Tests/DecorateGeneratorTests.cs @@ -13,7 +13,7 @@ namespace Tests.CodeAnalysis; public class DecorateGeneratorTests(ITestOutputHelper Output) { [Fact] - public async Task ErrorIfDecoratorIsNotService() + public async Task ErrorIfDecoratorLifetimeIsIncompatible() { var test = CreateTest( """ @@ -21,9 +21,10 @@ public async Task ErrorIfDecoratorIsNotService() public interface IFoo { } - [Service] + [Service(ServiceLifetime.Scoped)] public class Foo : IFoo { } + [Service(ServiceLifetime.Singleton)] public class FooDecorator(IFoo inner) : IFoo { } public static class Program @@ -38,15 +39,15 @@ public static void Main() """); test.ExpectedDiagnostics.Add( - Verifier.Diagnostic(IncrementalGenerator.DecoratorMustBeService) + Verifier.Diagnostic(IncrementalGenerator.DecoratorLifetimeIncompatible) .WithLocation(0) - .WithArguments("FooDecorator")); + .WithArguments("FooDecorator", "Singleton", "IFoo", "Scoped")); await test.RunAsync(); } [Fact] - public async Task ErrorIfDecoratorLifetimeIsIncompatible() + public async Task ErrorIfDecoratorConstructorDoesNotAcceptDecoratedService() { var test = CreateTest( """ @@ -54,11 +55,11 @@ public async Task ErrorIfDecoratorLifetimeIsIncompatible() public interface IFoo { } - [Service(ServiceLifetime.Scoped)] + [Service] public class Foo : IFoo { } - [Service(ServiceLifetime.Singleton)] - public class FooDecorator(IFoo inner) : IFoo { } + [Service] + public class FooDecorator : IFoo { } public static class Program { @@ -72,27 +73,31 @@ public static void Main() """); test.ExpectedDiagnostics.Add( - Verifier.Diagnostic(IncrementalGenerator.DecoratorLifetimeIncompatible) + Verifier.Diagnostic(IncrementalGenerator.DecoratorConstructorMissing) .WithLocation(0) - .WithArguments("FooDecorator", "Singleton", "IFoo", "Scoped")); + .WithArguments("FooDecorator", "IFoo")); await test.RunAsync(); } [Fact] - public async Task ErrorIfDecoratorConstructorDoesNotAcceptDecoratedService() + public async Task NoErrorIfDecoratorConstructorHasOtherDependencies() { var test = CreateTest( """ using Microsoft.Extensions.DependencyInjection; public interface IFoo { } + public interface IOtherDependency { } [Service] public class Foo : IFoo { } [Service] - public class FooDecorator : IFoo { } + public class OtherDependency : IOtherDependency { } + + [Service] + public class FooDecorator(IFoo inner, IOtherDependency other) : IFoo { } public static class Program { @@ -100,37 +105,28 @@ public static void Main() { var services = new ServiceCollection(); services.AddServices(); - {|#0:services.Decorate()|}; + services.Decorate(); } } """); - test.ExpectedDiagnostics.Add( - Verifier.Diagnostic(IncrementalGenerator.DecoratorConstructorMissing) - .WithLocation(0) - .WithArguments("FooDecorator", "IFoo")); - await test.RunAsync(); } [Fact] - public async Task NoErrorIfDecoratorConstructorHasOtherDependencies() + public async Task NoErrorIfDecoratorHasNoServiceAttribute() { var test = CreateTest( """ using Microsoft.Extensions.DependencyInjection; public interface IFoo { } - public interface IOtherDependency { } [Service] public class Foo : IFoo { } - [Service] - public class OtherDependency : IOtherDependency { } - - [Service] - public class FooDecorator(IFoo inner, IOtherDependency other) : IFoo { } + // Decorator intentionally has NO [Service] attribute + public class FooDecorator(IFoo inner) : IFoo { } public static class Program { @@ -143,6 +139,7 @@ public static void Main() } """); + // No diagnostics expected — decorator no longer requires [Service] await test.RunAsync(); } diff --git a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs index 3d63715..7a1dbe0 100644 --- a/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs +++ b/src/DependencyInjection.CodeAnalysis/IncrementalGenerator.cs @@ -31,15 +31,6 @@ public class IncrementalGenerator : IIncrementalGenerator DiagnosticSeverity.Warning, isEnabledByDefault: true); - public static DiagnosticDescriptor DecoratorMustBeService { get; } = - new DiagnosticDescriptor( - "DDI006", - "Decorator must be annotated with ServiceAttribute.", - "Decorator type {0} must be annotated with [Service] so its registration is generated before decoration.", - "Build", - DiagnosticSeverity.Error, - isEnabledByDefault: true); - public static DiagnosticDescriptor DecoratorLifetimeIncompatible { get; } = new DiagnosticDescriptor( "DDI007", @@ -387,6 +378,7 @@ static partial void DecorateServices(IServiceCollection { """); + // Emit non-keyed dispatch for (var i = 0; i < validDecorations.Count; i++) { var (decoration, _) = validDecorations[i]; @@ -413,6 +405,7 @@ static partial void DecorateKeyedServices(IServiceCollec { """); + // Emit keyed dispatch for (var i = 0; i < validDecorations.Count; i++) { var (decoration, _) = validDecorations[i]; @@ -424,7 +417,7 @@ static partial void DecorateKeyedServices(IServiceCollec builder.AppendLine($" if (typeof(TDecorated) == typeof({decorated}) && typeof(TDecorator) == typeof({decorator}))"); builder.AppendLine(" {"); - builder.AppendLine($" DecorateKeyedDescriptors<{decorated}, {decorator}>(services, key, CreateKeyedDecorator{i});"); + builder.AppendLine($" DecorateDescriptors<{decorated}, {decorator}>(services, key, CreateKeyedDecorator{i});"); builder.AppendLine(" return;"); builder.AppendLine(" }"); } @@ -434,21 +427,22 @@ static partial void DecorateKeyedServices(IServiceCollec } """); - for (var i = 0; i < validDecorations.Count; i++) + // Helper to build constructor arguments for a decorator factory. + // Handles the special case for the decorated service parameter and [FromKeyedServices]. + string BuildDecoratorArgs((DecoratedService Decoration, IMethodSymbol Constructor) entry, Compilation compilation, bool isKeyed) { - var (decoration, ctor) = validDecorations[i]; - if (decoration.IsKeyed) - continue; - + var (decoration, ctor) = entry; var decorated = decoration.TDecorated.ToFullName(compilation); - var decorator = decoration.TDecorator.ToFullName(compilation); - var usedDecorated = false; - var args = string.Join(", ", ctor.Parameters.Select(p => + bool usedDecorated = false; + + return string.Join(", ", ctor.Parameters.Select(p => { if (!usedDecorated && SymbolEqualityComparer.Default.Equals(p.Type, decoration.TDecorated)) { usedDecorated = true; - return $"GetDecorated<{decorated}>(s, descriptor)"; + return isKeyed + ? $"GetKeyedDecorated<{decorated}>(s, key, descriptor)" + : $"GetDecorated<{decorated}>(s, descriptor)"; } var fromKeyed = p.GetAttributes().FirstOrDefault(IsFromKeyed); @@ -457,12 +451,25 @@ static partial void DecorateKeyedServices(IServiceCollec return $"s.GetRequiredService<{p.Type.ToFullName(compilation)}>()"; })); + } + + // Emit non-keyed factory methods + for (var i = 0; i < validDecorations.Count; i++) + { + var (decoration, ctor) = validDecorations[i]; + if (decoration.IsKeyed) + continue; + + var decorated = decoration.TDecorated.ToFullName(compilation); + var decorator = decoration.TDecorator.ToFullName(compilation); + var args = BuildDecoratorArgs((decoration, ctor), compilation, isKeyed: false); builder.AppendLine(); builder.AppendLine($" static {decorated} CreateDecorator{i}(IServiceProvider s, ServiceDescriptor descriptor)"); builder.AppendLine($" => new {decorator}({args});"); } + // Emit keyed factory methods for (var i = 0; i < validDecorations.Count; i++) { var (decoration, ctor) = validDecorations[i]; @@ -471,21 +478,7 @@ static partial void DecorateKeyedServices(IServiceCollec var decorated = decoration.TDecorated.ToFullName(compilation); var decorator = decoration.TDecorator.ToFullName(compilation); - var usedDecorated = false; - var args = string.Join(", ", ctor.Parameters.Select(p => - { - if (!usedDecorated && SymbolEqualityComparer.Default.Equals(p.Type, decoration.TDecorated)) - { - usedDecorated = true; - return $"GetKeyedDecorated<{decorated}>(s, key, descriptor)"; - } - - var fromKeyed = p.GetAttributes().FirstOrDefault(IsFromKeyed); - if (fromKeyed is not null) - return $"s.GetRequiredKeyedService<{p.Type.ToFullName(compilation)}>({fromKeyed.ConstructorArguments[0].ToCSharpString()})"; - - return $"s.GetRequiredService<{p.Type.ToFullName(compilation)}>()"; - })); + var args = BuildDecoratorArgs((decoration, ctor), compilation, isKeyed: true); builder.AppendLine(); builder.AppendLine($" static {decorated} CreateKeyedDecorator{i}(IServiceProvider s, object? key, ServiceDescriptor descriptor)"); @@ -512,15 +505,6 @@ bool ValidateDecoration( var isValid = true; var decoratorLifetimes = GetDecoratorLifetimes(decoration, compilation); - if (decoratorLifetimes.IsEmpty) - { - ctx.ReportDiagnostic(Diagnostic.Create( - DecoratorMustBeService, - decoration.Location, - decoration.TDecorator.ToDisplayString())); - isValid = false; - } - var decoratedLifetimes = GetDecoratedLifetimes(decoration, services, compilation); if (!decoratorLifetimes.IsEmpty && !decoratedLifetimes.IsEmpty && (decoratorLifetimes.Length != 1 || decoratedLifetimes.Any(x => x != decoratorLifetimes[0]))) diff --git a/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs b/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs index 0a85fca..0cf32d3 100644 --- a/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs +++ b/src/DependencyInjection.CodeAnalysis/compile/AddServicesNoReflectionExtension.cs @@ -182,121 +182,123 @@ static void DecorateDescriptors( where TDecorated : class where TDecorator : class, TDecorated { - var descriptors = services - .Select((descriptor, index) => new { descriptor, index }) - .Where(x => - !x.descriptor.IsKeyedService && - x.descriptor.ServiceType == typeof(TDecorated) && - GetImplementationType(x.descriptor) != typeof(TDecorator)) - .ToArray(); - - if (descriptors.Length == 0) - { - throw new InvalidOperationException( - $"No service registration for {typeof(TDecorated)} was found. Call AddServices before Decorate, or register the decorated service before decorating it."); - } - - foreach (var item in descriptors) - { - services[item.index] = ServiceDescriptor.Describe( - typeof(TDecorated), - provider => factory(provider, item.descriptor), - item.descriptor.Lifetime); - } + DecorateDescriptors(services, null, (sp, _, sd) => factory(sp, sd)); } - static void DecorateKeyedDescriptors( + static void DecorateDescriptors( IServiceCollection services, object? key, Func factory) where TDecorated : class where TDecorator : class, TDecorated { + bool isKeyed = key is not null; + var descriptors = services .Select((descriptor, index) => new { descriptor, index }) .Where(x => - x.descriptor.IsKeyedService && + x.descriptor.IsKeyedService == isKeyed && x.descriptor.ServiceType == typeof(TDecorated) && - Equals(x.descriptor.ServiceKey, key) && + (!isKeyed || Equals(x.descriptor.ServiceKey, key)) && GetImplementationType(x.descriptor) != typeof(TDecorator)) .ToArray(); if (descriptors.Length == 0) { - throw new InvalidOperationException( - $"No keyed service registration for {typeof(TDecorated)} with key '{key}' was found. Call AddServices before Decorate, or register the decorated service before decorating it."); + throw new InvalidOperationException(isKeyed + ? $"No keyed service registration for {typeof(TDecorated)} with key '{key}' was found. Call AddServices before Decorate, or register the decorated service before decorating it." + : $"No service registration for {typeof(TDecorated)} was found. Call AddServices before Decorate, or register the decorated service before decorating it."); } foreach (var item in descriptors) { - services[item.index] = ServiceDescriptor.DescribeKeyed( - typeof(TDecorated), - key, - (provider, serviceKey) => factory(provider, serviceKey, item.descriptor), - item.descriptor.Lifetime); + if (!isKeyed) + { + services[item.index] = ServiceDescriptor.Describe( + typeof(TDecorated), + provider => factory(provider, null, item.descriptor), + item.descriptor.Lifetime); + } + else + { + services[item.index] = ServiceDescriptor.DescribeKeyed( + typeof(TDecorated), + key, + (provider, serviceKey) => factory(provider, serviceKey, item.descriptor), + item.descriptor.Lifetime); + } } } + + static TDecorated GetDecorated(IServiceProvider provider, ServiceDescriptor descriptor) where TDecorated : class - { - object? service; + => GetDecoratedCore(provider, null, descriptor); - if (descriptor.IsKeyedService) - { - throw new InvalidOperationException($"Keyed service registrations for {typeof(TDecorated)} cannot be decorated by this overload."); - } + static TDecorated GetKeyedDecorated(IServiceProvider provider, object? key, ServiceDescriptor descriptor) + where TDecorated : class + => GetDecoratedCore(provider, key, descriptor); - if (descriptor.ImplementationInstance != null) - { - service = descriptor.ImplementationInstance; - } - else if (descriptor.ImplementationFactory != null) - { - service = descriptor.ImplementationFactory(provider); - } - else if (descriptor.ImplementationType != null) + static TDecorated GetDecoratedCore(IServiceProvider provider, object? key, ServiceDescriptor descriptor) + where TDecorated : class + { + bool isKeyed = key is not null; + + if (isKeyed) { - service = ActivatorUtilities.GetServiceOrCreateInstance(provider, descriptor.ImplementationType); + if (!descriptor.IsKeyedService) + throw new InvalidOperationException($"Non-keyed service registrations for {typeof(TDecorated)} cannot be decorated by this overload."); } else { - throw new InvalidOperationException($"Unsupported service registration for {typeof(TDecorated)}."); + if (descriptor.IsKeyedService) + throw new InvalidOperationException($"Keyed service registrations for {typeof(TDecorated)} cannot be decorated by this overload."); } - return service as TDecorated ?? - throw new InvalidOperationException($"The decorated registration did not produce an instance of {typeof(TDecorated)}."); - } - - static TDecorated GetKeyedDecorated(IServiceProvider provider, object? key, ServiceDescriptor descriptor) - where TDecorated : class - { object? service; - if (!descriptor.IsKeyedService) - { - throw new InvalidOperationException($"Non-keyed service registrations for {typeof(TDecorated)} cannot be decorated by this overload."); - } - - if (descriptor.KeyedImplementationInstance != null) - { - service = descriptor.KeyedImplementationInstance; - } - else if (descriptor.KeyedImplementationFactory != null) - { - service = descriptor.KeyedImplementationFactory(provider, key); - } - else if (descriptor.KeyedImplementationType != null) + if (descriptor.IsKeyedService) { - service = ActivatorUtilities.GetServiceOrCreateInstance(provider, descriptor.KeyedImplementationType); + if (descriptor.KeyedImplementationInstance != null) + { + service = descriptor.KeyedImplementationInstance; + } + else if (descriptor.KeyedImplementationFactory != null) + { + service = descriptor.KeyedImplementationFactory(provider, key); + } + else if (descriptor.KeyedImplementationType != null) + { + service = ActivatorUtilities.GetServiceOrCreateInstance(provider, descriptor.KeyedImplementationType); + } + else + { + throw new InvalidOperationException($"Unsupported keyed service registration for {typeof(TDecorated)}."); + } } else { - throw new InvalidOperationException($"Unsupported keyed service registration for {typeof(TDecorated)}."); + if (descriptor.ImplementationInstance != null) + { + service = descriptor.ImplementationInstance; + } + else if (descriptor.ImplementationFactory != null) + { + service = descriptor.ImplementationFactory(provider); + } + else if (descriptor.ImplementationType != null) + { + service = ActivatorUtilities.GetServiceOrCreateInstance(provider, descriptor.ImplementationType); + } + else + { + throw new InvalidOperationException($"Unsupported service registration for {typeof(TDecorated)}."); + } } return service as TDecorated ?? - throw new InvalidOperationException($"The decorated keyed registration did not produce an instance of {typeof(TDecorated)}."); + throw new InvalidOperationException($"The decorated registration did not produce an instance of {typeof(TDecorated)}."); } static Type? GetImplementationType(ServiceDescriptor descriptor) diff --git a/src/DependencyInjection.Tests/GenerationTests.cs b/src/DependencyInjection.Tests/GenerationTests.cs index ad1fdef..cc95ed5 100644 --- a/src/DependencyInjection.Tests/GenerationTests.cs +++ b/src/DependencyInjection.Tests/GenerationTests.cs @@ -362,6 +362,40 @@ public void DecorateKeyedThrowsIfDecoratedServiceIsNotRegistered() Assert.Contains("missing", ex.Message); } + [Fact] + public void DecorateServiceReadmeExample() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate(); + var services = collection.BuildServiceProvider(); + + using var scope = services.CreateScope(); + var provider = scope.ServiceProvider; + + var instance = Assert.IsType(provider.GetRequiredService()); + Assert.IsType(instance.Inner); + + Assert.Same(instance, provider.GetRequiredService>().Invoke()); + Assert.Same(instance, provider.GetRequiredService>().Value); + } + + [Fact] + public void DecorateWorksWithDecoratorThatHasNoServiceAttribute() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate(); + var services = collection.BuildServiceProvider(); + + var instance = Assert.IsType(services.GetRequiredService()); + Assert.IsType(instance.Inner); + Assert.Same(services.GetRequiredService(), instance.Singleton); + + Assert.Same(instance, services.GetRequiredService>().Invoke()); + Assert.Same(instance, services.GetRequiredService>().Value); + } + [GenerationTests.Service(ServiceLifetime.Singleton)] public class MyAttributedService : IAsyncDisposable { @@ -526,4 +560,38 @@ public class KeyedDecoratedServiceDecorator(IKeyedDecoratedService inner, Single { public IKeyedDecoratedService Inner => inner; public SingletonService Singleton => singleton; +} + +public interface IReadmeNotificationService +{ + void Send(string message); +} + +[Service(ServiceLifetime.Scoped)] +public class ReadmeEmailNotificationService : IReadmeNotificationService +{ + public void Send(string message) => Console.WriteLine($"[Email] {message}"); +} + +[Service(ServiceLifetime.Scoped)] +public class ReadmeLoggingNotificationService(IReadmeNotificationService inner) : IReadmeNotificationService +{ + public IReadmeNotificationService Inner => inner; + public void Send(string message) + { + Console.WriteLine("Sending notification..."); + inner.Send(message); + } +} + +public interface IUnattributedDecorated { } + +[Service] +public class UnattributedDecoratedImpl : IUnattributedDecorated { } + +// Intentionally no [Service] attribute — decoration should still work. +public class UnattributedDecorator(IUnattributedDecorated inner, SingletonService singleton) : IUnattributedDecorated +{ + public IUnattributedDecorated Inner => inner; + public SingletonService Singleton => singleton; } \ No newline at end of file From ec4ee629795cd6b4cab5463119e635ff34285799 Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Mon, 29 Jun 2026 06:15:57 -0300 Subject: [PATCH 9/9] Ensure IEnumerable resolves decorated instances Add regression test verifying that constructor-injected IEnumerable receives properly decorated instances for each underlying registration after Decorate() is applied. --- .../GenerationTests.cs | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/DependencyInjection.Tests/GenerationTests.cs b/src/DependencyInjection.Tests/GenerationTests.cs index cc95ed5..5307a41 100644 --- a/src/DependencyInjection.Tests/GenerationTests.cs +++ b/src/DependencyInjection.Tests/GenerationTests.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.ComponentModel.Composition; using System.Linq; using System.Threading; @@ -319,6 +320,22 @@ public void DecorateMultipleRegistrations() Assert.Contains(instances, x => x.Inner is SecondMultipleDecoratedService); } + [Fact] + public void DecorateMultipleRegistrationsAsIEnumerableDependency() + { + var collection = new ServiceCollection(); + collection.AddServices(); + collection.Decorate(); + var services = collection.BuildServiceProvider(); + + var consumer = services.GetRequiredService(); + + var decorated = consumer.Services.Cast().ToList(); + Assert.Equal(2, decorated.Count); + Assert.Contains(decorated, x => x.Inner is FirstMultipleDecoratedService); + Assert.Contains(decorated, x => x.Inner is SecondMultipleDecoratedService); + } + [Fact] public void DecorateKeyedService() { @@ -594,4 +611,10 @@ public class UnattributedDecorator(IUnattributedDecorated inner, SingletonServic { public IUnattributedDecorated Inner => inner; public SingletonService Singleton => singleton; -} \ No newline at end of file +} + +[Service] +public class IEnumerableDecoratedConsumer(IEnumerable services) +{ + public IEnumerable Services => services; +}