Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,41 @@
using Microsoft.AspNetCore.Http.HttpResults;
using SharpGrip.FluentValidation.AutoValidation.Endpoints.Results;

namespace SharpGrip.FluentValidation.AutoValidation.Endpoints.Configuration
namespace SharpGrip.FluentValidation.AutoValidation.Endpoints.Configuration;

public class AutoValidationEndpointsConfiguration
{
public class AutoValidationEndpointsConfiguration
/// <summary>
/// Gets a value indicating whether the validation process should look for validators
/// registered for interfaces or base types when a validator for the concrete type is not found.
/// </summary>
public bool UseBaseTypeValidations { get; private set; }

/// <summary>
/// Holds the overridden result factory. This property is meant for infrastructure and should not be used by application code.
/// </summary>
public Type? OverriddenResultFactory { get; private set; }

/// <summary>
/// Overrides the default result factory with a custom result factory. Custom result factories are required to implement <see cref="IFluentValidationAutoValidationResultFactory" />.
/// The default result factory returns the validation errors wrapped in a <see cref="ValidationProblem" /> object.
/// </summary>
/// <see cref="FluentValidationAutoValidationDefaultResultFactory" />
/// <typeparam name="TResultFactory">The custom result factory implementing <see cref="IFluentValidationAutoValidationResultFactory" />.</typeparam>
public AutoValidationEndpointsConfiguration OverrideDefaultResultFactoryWith<TResultFactory>() where TResultFactory : IFluentValidationAutoValidationResultFactory
{
/// <summary>
/// Holds the overridden result factory. This property is meant for infrastructure and should not be used by application code.
/// </summary>
public Type? OverriddenResultFactory { get; private set; }
OverriddenResultFactory = typeof(TResultFactory);
return this;
}

/// <summary>
/// Overrides the default result factory with a custom result factory. Custom result factories are required to implement <see cref="IFluentValidationAutoValidationResultFactory"/>.
/// The default result factory returns the validation errors wrapped in a <see cref="ValidationProblem"/> object.
/// </summary>
/// <see cref="FluentValidationAutoValidationDefaultResultFactory"/>
/// <typeparam name="TResultFactory">The custom result factory implementing <see cref="IFluentValidationAutoValidationResultFactory"/>.</typeparam>
public void OverrideDefaultResultFactoryWith<TResultFactory>() where TResultFactory : IFluentValidationAutoValidationResultFactory
{
OverriddenResultFactory = typeof(TResultFactory);
}
/// <summary>
/// Enables the fallback mechanism to search for validators in the type hierarchy (interfaces and base classes)
/// if no specific validator is registered for the primary parameter type.
/// </summary>
/// <returns>The current <see cref="AutoValidationEndpointsConfiguration" /> instance for fluent chaining.</returns>
public AutoValidationEndpointsConfiguration WithBaseTypeValidations()
{
UseBaseTypeValidations = true;
return this;
}
}
Original file line number Diff line number Diff line change
@@ -1,82 +1,100 @@
using System.Threading.Tasks;
using System;
using System.Threading.Tasks;
using FluentValidation;
using FluentValidation.Results;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using SharpGrip.FluentValidation.AutoValidation.Endpoints.Configuration;
using SharpGrip.FluentValidation.AutoValidation.Endpoints.Interceptors;
using SharpGrip.FluentValidation.AutoValidation.Endpoints.Results;
using SharpGrip.FluentValidation.AutoValidation.Shared.Extensions;

namespace SharpGrip.FluentValidation.AutoValidation.Endpoints.Filters
{
public class FluentValidationAutoValidationEndpointFilter(ILogger<FluentValidationAutoValidationEndpointFilter> logger) : IEndpointFilter
public class FluentValidationAutoValidationEndpointFilter(ILogger<FluentValidationAutoValidationEndpointFilter> logger, IOptions<AutoValidationEndpointsConfiguration> options) : IEndpointFilter
{
public async ValueTask<object?> InvokeAsync(EndpointFilterInvocationContext endpointFilterInvocationContext, EndpointFilterDelegate next)
{
var serviceProvider = endpointFilterInvocationContext.HttpContext.RequestServices;

foreach (var argument in endpointFilterInvocationContext.Arguments)
{
if (argument != null && argument.GetType().IsCustomType() && serviceProvider.GetValidator(argument.GetType()) is IValidator validator)
if (argument == null) continue;

if (!argument.GetType().IsCustomType() || serviceProvider.GetValidator(argument.GetType(), options.Value.UseBaseTypeValidations) is not IValidator validator)
{
logger.LogDebug("Starting validation for argument of type '{Type}'.", argument.GetType().Name);
logger.LogDebug("Skipping argument of type '{Type}'. It's not a custom type, or no validator was found for this type.", argument.GetType().Name);
continue;
}

logger.LogDebug("Starting validation for argument of type '{Type}'.", argument.GetType().Name);

var validatorInterceptor = validator as IValidatorInterceptor;
var globalValidationInterceptor = serviceProvider.GetService<IGlobalValidationInterceptor>();
var validationResult = await ExecuteValidation(endpointFilterInvocationContext, validator, serviceProvider, argument);

IValidationContext validationContext = new ValidationContext<object>(argument);
if (!validationResult.IsValid) return CreateAInvalidResult(endpointFilterInvocationContext, argument, validationResult, serviceProvider);

if (validatorInterceptor != null)
{
logger.LogDebug("Invoking validator interceptor BeforeValidation for argument '{Argument}'.", argument.GetType().Name);
validationContext = await validatorInterceptor.BeforeValidation(endpointFilterInvocationContext, validationContext, endpointFilterInvocationContext.HttpContext.RequestAborted) ?? validationContext;
}
logger.LogDebug("Validation result valid for argument '{Argument}'.", argument.GetType().Name);
}

if (globalValidationInterceptor != null)
{
logger.LogDebug("Invoking global validation interceptor BeforeValidation for argument '{Argument}'.", argument.GetType().Name);
validationContext = await globalValidationInterceptor.BeforeValidation(endpointFilterInvocationContext, validationContext, endpointFilterInvocationContext.HttpContext.RequestAborted) ?? validationContext;
}
return await next(endpointFilterInvocationContext);
}

var validationResult = await validator.ValidateAsync(validationContext, endpointFilterInvocationContext.HttpContext.RequestAborted);
private object CreateAInvalidResult(EndpointFilterInvocationContext endpointFilterInvocationContext, object argument, ValidationResult validationResult, IServiceProvider serviceProvider)
{
logger.LogDebug("Validation result not valid for argument '{Argument}': {ErrorCount} validation error(s) found.", argument.GetType().Name, validationResult.Errors.Count);

if (validatorInterceptor != null)
{
logger.LogDebug("Invoking validator interceptor AfterValidation for argument '{Argument}'.", argument.GetType().Name);
validationResult = await validatorInterceptor.AfterValidation(endpointFilterInvocationContext, validationContext, validationResult, endpointFilterInvocationContext.HttpContext.RequestAborted) ?? validationResult;
}
var fluentValidationAutoValidationResultFactory = serviceProvider.GetService<IFluentValidationAutoValidationResultFactory>();

if (globalValidationInterceptor != null)
{
logger.LogDebug("Invoking global validation interceptor AfterValidation for argument '{Argument}'.", argument.GetType().Name);
validationResult = await globalValidationInterceptor.AfterValidation(endpointFilterInvocationContext, validationContext, validationResult, endpointFilterInvocationContext.HttpContext.RequestAborted) ?? validationResult;
}
logger.LogDebug("Creating result for path '{Path}'.", endpointFilterInvocationContext.HttpContext.Request.Path);

if (!validationResult.IsValid)
{
logger.LogDebug("Validation result not valid for argument '{Argument}': {ErrorCount} validation error(s) found.", argument.GetType().Name, validationResult.Errors.Count);
if (fluentValidationAutoValidationResultFactory != null)
{
logger.LogTrace("Creating result for path '{Path}' using a custom result factory.", endpointFilterInvocationContext.HttpContext.Request.Path);

var fluentValidationAutoValidationResultFactory = serviceProvider.GetService<IFluentValidationAutoValidationResultFactory>();
return fluentValidationAutoValidationResultFactory.CreateResult(endpointFilterInvocationContext, validationResult);
}

logger.LogDebug("Creating result for path '{Path}'.", endpointFilterInvocationContext.HttpContext.Request.Path);
logger.LogTrace("Creating result for path '{Path}' using the default result factory.", endpointFilterInvocationContext.HttpContext.Request.Path);

if (fluentValidationAutoValidationResultFactory != null)
{
logger.LogTrace("Creating result for path '{Path}' using a custom result factory.", endpointFilterInvocationContext.HttpContext.Request.Path);
return new FluentValidationAutoValidationDefaultResultFactory().CreateResult(endpointFilterInvocationContext, validationResult);
}

return fluentValidationAutoValidationResultFactory.CreateResult(endpointFilterInvocationContext, validationResult);
}
private async ValueTask<ValidationResult> ExecuteValidation(EndpointFilterInvocationContext endpointFilterInvocationContext, IValidator validator, IServiceProvider serviceProvider, object argument)
{
var validatorInterceptor = validator as IValidatorInterceptor;
var globalValidationInterceptor = serviceProvider.GetService<IGlobalValidationInterceptor>();

logger.LogTrace("Creating result for path '{Path}' using the default result factory.", endpointFilterInvocationContext.HttpContext.Request.Path);
IValidationContext validationContext = new ValidationContext<object>(argument);

return new FluentValidationAutoValidationDefaultResultFactory().CreateResult(endpointFilterInvocationContext, validationResult);
}
if (validatorInterceptor != null)
{
logger.LogDebug("Invoking validator interceptor BeforeValidation for argument '{Argument}'.", argument.GetType().Name);
validationContext = await validatorInterceptor.BeforeValidation(endpointFilterInvocationContext, validationContext, endpointFilterInvocationContext.HttpContext.RequestAborted) ?? validationContext;
}

logger.LogDebug("Validation result valid for argument '{Argument}'.", argument.GetType().Name);
}
if (globalValidationInterceptor != null)
{
logger.LogDebug("Invoking global validation interceptor BeforeValidation for argument '{Argument}'.", argument.GetType().Name);
validationContext = await globalValidationInterceptor.BeforeValidation(endpointFilterInvocationContext, validationContext, endpointFilterInvocationContext.HttpContext.RequestAborted) ?? validationContext;
}

return await next(endpointFilterInvocationContext);
var validationResult = await validator.ValidateAsync(validationContext, endpointFilterInvocationContext.HttpContext.RequestAborted);

if (validatorInterceptor != null)
{
logger.LogDebug("Invoking validator interceptor AfterValidation for argument '{Argument}'.", argument.GetType().Name);
validationResult = await validatorInterceptor.AfterValidation(endpointFilterInvocationContext, validationContext, validationResult, endpointFilterInvocationContext.HttpContext.RequestAborted) ?? validationResult;
}

if (globalValidationInterceptor != null)
{
logger.LogDebug("Invoking global validation interceptor AfterValidation for argument '{Argument}'.", argument.GetType().Name);
validationResult = await globalValidationInterceptor.AfterValidation(endpointFilterInvocationContext, validationContext, validationResult, endpointFilterInvocationContext.HttpContext.RequestAborted) ?? validationResult;
}

return validationResult;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public async Task OnActionExecutionAsync(ActionExecutingContext actionExecutingC
var hasAutoValidateAlwaysAttribute = parameterInfo?.HasCustomAttribute<AutoValidateAlwaysAttribute>() ?? false;
var hasAutoValidateNeverAttribute = parameterInfo?.HasCustomAttribute<AutoValidateNeverAttribute>() ?? false;

if (subject != null && parameterType != null && parameterType.IsCustomType() && !hasAutoValidateNeverAttribute && (hasAutoValidateAlwaysAttribute || HasValidBindingSource(bindingSource)) && serviceProvider.GetValidator(parameterType) is IValidator validator)
if (subject != null && parameterType != null && parameterType.IsCustomType() && !hasAutoValidateNeverAttribute && (hasAutoValidateAlwaysAttribute || HasValidBindingSource(bindingSource)) && serviceProvider.GetValidator(parameterType, false) is IValidator validator)
{
logger.LogDebug("Validating parameter '{Parameter}' of type '{Type}' for action '{Action}' on controller '{Controller}'.", parameter.Name, parameterType.Name, controllerActionDescriptor.ActionName, controllerActionDescriptor.ControllerName);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,46 @@
using System;
using FluentValidation;

namespace SharpGrip.FluentValidation.AutoValidation.Shared.Extensions
namespace SharpGrip.FluentValidation.AutoValidation.Shared.Extensions;

public static class ServiceProviderExtensions
{
public static class ServiceProviderExtensions
public static object? GetValidator(this IServiceProvider serviceProvider, Type type, bool useBaseTypeValidations)
{
var validator = serviceProvider.GetService(typeof(IValidator<>).MakeGenericType(type));
if (validator is not null) return validator;
if (!useBaseTypeValidations) return null;

return GetValidatorFromBaseClasses(serviceProvider, type)
?? GetValidatorFromInterfaces(serviceProvider, type);
}

private static object? GetValidatorFromBaseClasses(IServiceProvider serviceProvider, Type type)
{
public static object? GetValidator(this IServiceProvider serviceProvider, Type type)
var baseType = type.BaseType;
while (baseType is not null && baseType != typeof(object))
{
return serviceProvider.GetService(typeof(IValidator<>).MakeGenericType(type));
var baseValidatorType = typeof(IValidator<>).MakeGenericType(baseType);
var baseValidator = serviceProvider.GetService(baseValidatorType);

if (baseValidator is not null) return baseValidator;

baseType = baseType.BaseType;
}

return null;
}

private static object? GetValidatorFromInterfaces(IServiceProvider serviceProvider, Type type)
{
foreach (var interfaceType in type.GetInterfaces())
{
var interfaceValidatorType = typeof(IValidator<>).MakeGenericType(interfaceType);
var interfaceValidator = serviceProvider.GetService(interfaceValidatorType);

if (interfaceValidator is not null) return interfaceValidator;
}

return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,27 @@ namespace SharpGrip.FluentValidation.AutoValidation.Shared.Extensions
{
public static class TypeExtensions
{
private static readonly HashSet<Type> builtInTypes =
[
typeof(string),
typeof(decimal),
typeof(DateTime),
typeof(DateTimeOffset),
typeof(TimeSpan),
typeof(DateOnly),
typeof(TimeOnly),
typeof(Uri),
typeof(Guid),
typeof(Enum)
];

public static bool IsCustomType(this Type? type)
{
if (type == null || type.IsEnum || type.IsPrimitive)
{
return false;
}

var builtInTypes = new HashSet<Type>
{
typeof(string),
typeof(decimal),
typeof(DateTime),
typeof(DateTimeOffset),
typeof(TimeSpan),
typeof(DateOnly),
typeof(TimeOnly),
typeof(Uri),
typeof(Guid),
typeof(Enum)
};

if (builtInTypes.Contains(type))
{
return false;
Expand Down
Loading