diff --git a/MCPify/Hosting/McpOAuthAuthenticationMiddleware.cs b/MCPify/Hosting/McpOAuthAuthenticationMiddleware.cs
index b1d59a0..0d50ef2 100644
--- a/MCPify/Hosting/McpOAuthAuthenticationMiddleware.cs
+++ b/MCPify/Hosting/McpOAuthAuthenticationMiddleware.cs
@@ -1,7 +1,5 @@
using MCPify.Core;
using MCPify.Core.Auth;
-using Microsoft.AspNetCore.Http;
-using Microsoft.Extensions.DependencyInjection;
using Microsoft.Net.Http.Headers;
namespace MCPify.Hosting;
@@ -10,9 +8,6 @@ public class McpOAuthAuthenticationMiddleware
{
private readonly RequestDelegate _next;
- ///
- /// Key for storing token validation result in HttpContext.Items for downstream use.
- ///
public const string TokenValidationResultKey = "McpTokenValidationResult";
public McpOAuthAuthenticationMiddleware(RequestDelegate next)
@@ -22,57 +17,48 @@ public McpOAuthAuthenticationMiddleware(RequestDelegate next)
public async Task InvokeAsync(HttpContext context)
{
- // Skip check for metadata endpoint and other non-MCP endpoints
var path = context.Request.Path;
if (path.StartsWithSegments("/.well-known") ||
path.StartsWithSegments("/swagger") ||
path.StartsWithSegments("/health") ||
- path.StartsWithSegments("/connect") || // OpenIddict or Auth endpoints
- path.StartsWithSegments("/auth")) // Callback paths
+ path.StartsWithSegments("/connect") ||
+ path.StartsWithSegments("/auth"))
{
await _next(context);
return;
}
- // Check if OAuth is configured
- var oauthStore = context.RequestServices.GetService();
var options = context.RequestServices.GetService();
+ var oauthStore = context.RequestServices.GetService();
+
+ var oauthConfigurations = oauthStore?.GetConfigurations().ToList() ?? [];
+ var validationOptions = options?.TokenValidation;
+ var tokenValidationEnabled = validationOptions?.EnableJwtValidation == true;
- if (oauthStore == null || !oauthStore.GetConfigurations().Any())
+ var challengeScopes = BuildChallengeScopes(oauthConfigurations, validationOptions);
+ var authRequired = oauthConfigurations.Count > 0 || tokenValidationEnabled;
+
+ if (!authRequired)
{
await _next(context);
return;
}
- var accessor = context.RequestServices.GetService();
var resourceUrl = GetResourceUrl(context, options);
+ var accessor = context.RequestServices.GetService();
- // Check for Authorization header
- string? authorization = context.Request.Headers[HeaderNames.Authorization];
- if (string.IsNullOrEmpty(authorization) || !authorization.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
- {
- // No token - return 401 challenge
- await WriteChallengeResponse(context, oauthStore, resourceUrl, null, null);
- return;
- }
-
- // Extract token
- var token = authorization.Substring("Bearer ".Length).Trim();
- if (string.IsNullOrEmpty(token))
+ if (!TryGetBearerToken(context, out var token))
{
- await WriteChallengeResponse(context, oauthStore, resourceUrl, null, null);
+ await WriteChallengeResponse(context, resourceUrl, challengeScopes, null, null);
return;
}
- // Set token on accessor for downstream use
if (accessor != null)
{
accessor.AccessToken = token;
}
- // Perform token validation if enabled
- var validationOptions = options?.TokenValidation;
- if (validationOptions?.EnableJwtValidation == true)
+ if (tokenValidationEnabled && validationOptions != null)
{
var validator = context.RequestServices.GetService();
if (validator != null)
@@ -82,31 +68,24 @@ public async Task InvokeAsync(HttpContext context)
: null;
var validationResult = await validator.ValidateAsync(token, expectedAudience, context.RequestAborted);
-
- // Store validation result for downstream use
context.Items[TokenValidationResultKey] = validationResult;
if (!validationResult.IsValid)
{
- // Token is invalid (expired, malformed, wrong audience) - return 401
- await WriteInvalidTokenResponse(context, oauthStore, resourceUrl,
+ await WriteChallengeResponse(context, resourceUrl, challengeScopes,
validationResult.ErrorCode ?? "invalid_token",
validationResult.ErrorDescription ?? "Token validation failed");
return;
}
- // Validate scopes if enabled
if (validationOptions.ValidateScopes)
{
var scopeStore = context.RequestServices.GetService();
if (scopeStore != null)
{
- // Use default validation (no specific tool name available at middleware level)
var scopeResult = scopeStore.ValidateScopesForTool("*", validationResult.Scopes);
-
if (!scopeResult.IsValid)
{
- // Token is valid but lacks required scopes - return 403
await WriteInsufficientScopeResponse(context, resourceUrl, scopeResult.MissingScopes);
return;
}
@@ -134,50 +113,69 @@ private static string GetResourceUrl(HttpContext context, McpifyOptions? options
return resourceUrl.TrimEnd('/');
}
- private static async Task WriteChallengeResponse(
- HttpContext context,
- OAuthConfigurationStore oauthStore,
- string resourceUrl,
- string? errorCode,
- string? errorDescription)
+ private static IReadOnlyList BuildChallengeScopes(
+ IReadOnlyCollection configurations,
+ TokenValidationOptions? validationOptions)
{
- var metadataUrl = $"{resourceUrl}/.well-known/oauth-protected-resource";
+ var defaultScopes = validationOptions?.DefaultRequiredScopes;
+ var hasDefaultScopes = defaultScopes is { Count: > 0 };
- // Collect all scopes from OAuth configurations per MCP spec
- var allScopes = oauthStore.GetConfigurations()
- .SelectMany(c => c.Scopes.Keys)
- .Distinct(StringComparer.OrdinalIgnoreCase)
- .ToList();
+ if (configurations.Count == 0 && !hasDefaultScopes)
+ {
+ return Array.Empty();
+ }
- context.Response.StatusCode = StatusCodes.Status401Unauthorized;
+ var scopes = new HashSet(StringComparer.OrdinalIgnoreCase);
+
+ foreach (var configuration in configurations)
+ {
+ foreach (var scope in configuration.Scopes.Keys)
+ {
+ scopes.Add(scope);
+ }
+ }
- // Build WWW-Authenticate header per MCP Authorization spec
- var wwwAuthenticate = BuildWwwAuthenticateHeader(metadataUrl, allScopes, errorCode, errorDescription);
- context.Response.Headers[HeaderNames.WWWAuthenticate] = wwwAuthenticate;
+ if (hasDefaultScopes && defaultScopes != null)
+ {
+ foreach (var scope in defaultScopes)
+ {
+ scopes.Add(scope);
+ }
+ }
+
+ return scopes.ToList();
}
- private static async Task WriteInvalidTokenResponse(
- HttpContext context,
- OAuthConfigurationStore oauthStore,
- string resourceUrl,
- string errorCode,
- string errorDescription)
+ private static bool TryGetBearerToken(HttpContext context, out string token)
{
- var metadataUrl = $"{resourceUrl}/.well-known/oauth-protected-resource";
+ token = string.Empty;
+ string? authorization = context.Request.Headers[HeaderNames.Authorization];
- // Collect all scopes from OAuth configurations
- var allScopes = oauthStore.GetConfigurations()
- .SelectMany(c => c.Scopes.Keys)
- .Distinct(StringComparer.OrdinalIgnoreCase)
- .ToList();
+ if (string.IsNullOrEmpty(authorization) ||
+ !authorization.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
+ {
+ return false;
+ }
- context.Response.StatusCode = StatusCodes.Status401Unauthorized;
+ token = authorization.Substring("Bearer ".Length).Trim();
+ return !string.IsNullOrEmpty(token);
+ }
- var wwwAuthenticate = BuildWwwAuthenticateHeader(metadataUrl, allScopes, errorCode, errorDescription);
- context.Response.Headers[HeaderNames.WWWAuthenticate] = wwwAuthenticate;
+ private static Task WriteChallengeResponse(
+ HttpContext context,
+ string resourceUrl,
+ IReadOnlyList scopes,
+ string? errorCode,
+ string? errorDescription)
+ {
+ context.Response.StatusCode = StatusCodes.Status401Unauthorized;
+ var metadataUrl = $"{resourceUrl}/.well-known/oauth-protected-resource";
+ context.Response.Headers[HeaderNames.WWWAuthenticate] =
+ BuildWwwAuthenticateHeader(metadataUrl, scopes, errorCode, errorDescription);
+ return Task.CompletedTask;
}
- private static async Task WriteInsufficientScopeResponse(
+ private static Task WriteInsufficientScopeResponse(
HttpContext context,
string resourceUrl,
IReadOnlyList requiredScopes)
@@ -186,7 +184,6 @@ private static async Task WriteInsufficientScopeResponse(
context.Response.StatusCode = StatusCodes.Status403Forbidden;
- // Build WWW-Authenticate header for insufficient_scope per RFC 6750 Section 3.1
var parts = new List
{
"Bearer",
@@ -201,6 +198,7 @@ private static async Task WriteInsufficientScopeResponse(
}
context.Response.Headers[HeaderNames.WWWAuthenticate] = string.Join(", ", parts);
+ return Task.CompletedTask;
}
private static string BuildWwwAuthenticateHeader(
@@ -209,7 +207,15 @@ private static string BuildWwwAuthenticateHeader(
string? errorCode,
string? errorDescription)
{
- var parts = new List { $"Bearer resource_metadata=\"{metadataUrl}\"" };
+ if (string.IsNullOrEmpty(errorCode) && string.IsNullOrEmpty(errorDescription) && scopes.Count == 0)
+ {
+ return $"Bearer resource_metadata=\"{metadataUrl}\"";
+ }
+
+ var parts = new List(4)
+ {
+ $"Bearer resource_metadata=\"{metadataUrl}\""
+ };
if (!string.IsNullOrEmpty(errorCode))
{
@@ -218,7 +224,6 @@ private static string BuildWwwAuthenticateHeader(
if (!string.IsNullOrEmpty(errorDescription))
{
- // Escape quotes in description
var escapedDescription = errorDescription.Replace("\"", "\\\"");
parts.Add($"error_description=\"{escapedDescription}\"");
}
diff --git a/Sample/Extensions/DemoServiceExtensions.cs b/Sample/Extensions/DemoServiceExtensions.cs
index 8f53b4d..029a07f 100644
--- a/Sample/Extensions/DemoServiceExtensions.cs
+++ b/Sample/Extensions/DemoServiceExtensions.cs
@@ -143,6 +143,12 @@ public static IServiceCollection AddDemoMcpify(this IServiceCollection services,
options.Transport = transport;
options.ResourceUrlOverride = baseUrl;
+ options.TokenValidation = new TokenValidationOptions
+ {
+ EnableJwtValidation = true,
+ ValidateAudience = true
+ };
+
// Expose the local API (which is now the "Real" API)
options.LocalEndpoints = new()
{
diff --git a/Sample/MCPify.Sample.csproj b/Sample/MCPify.Sample.csproj
index 96ff7af..83159b9 100644
--- a/Sample/MCPify.Sample.csproj
+++ b/Sample/MCPify.Sample.csproj
@@ -17,28 +17,28 @@
-
-
-
-
+
+
+
+
-
-
-
+
+
+
-
+
-
+
diff --git a/Tests/MCPify.Tests/Integration/OAuthChallengeTokenValidationTests.cs b/Tests/MCPify.Tests/Integration/OAuthChallengeTokenValidationTests.cs
new file mode 100644
index 0000000..4b5a033
--- /dev/null
+++ b/Tests/MCPify.Tests/Integration/OAuthChallengeTokenValidationTests.cs
@@ -0,0 +1,76 @@
+using System.Net;
+using System.Net.Http.Headers;
+using System.Text;
+using MCPify.Core;
+using MCPify.Core.Auth;
+using MCPify.Hosting;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.AspNetCore.Hosting;
+using Microsoft.AspNetCore.TestHost;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Hosting;
+
+namespace MCPify.Tests.Integration;
+
+public class OAuthChallengeTokenValidationTests
+{
+ [Fact]
+ public async Task PostWithoutSession_ReturnsUnauthorizedChallenge_WhenTokenValidationEnabled()
+ {
+ using var host = await new HostBuilder()
+ .ConfigureWebHost(webBuilder =>
+ {
+ webBuilder
+ .UseTestServer()
+ .ConfigureServices(services =>
+ {
+ services.AddLogging();
+ services.AddRouting();
+ services.AddMcpify(options =>
+ {
+ options.Transport = McpTransportType.Http;
+ options.TokenValidation = new TokenValidationOptions
+ {
+ EnableJwtValidation = true,
+ ValidateAudience = true
+ };
+ });
+ })
+ .Configure(app =>
+ {
+ app.UseRouting();
+ app.UseMcpifyContext();
+ app.UseMcpifyOAuth();
+ app.UseEndpoints(endpoints =>
+ {
+ endpoints.MapMcpifyEndpoint();
+ });
+ });
+ })
+ .StartAsync();
+
+ var options = host.Services.GetRequiredService();
+ Assert.True(options.TokenValidation?.EnableJwtValidation, "Token validation should be enabled");
+ var validationOptions = host.Services.GetRequiredService();
+ Assert.True(validationOptions.EnableJwtValidation, "TokenValidationOptions from DI should have EnableJwtValidation true");
+
+ var client = host.GetTestClient();
+
+ using var request = new HttpRequestMessage(HttpMethod.Post, "/")
+ {
+ Content = new StringContent("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"ping\",\"params\":{}}", Encoding.UTF8, "application/json")
+ };
+ request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
+ request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
+
+ var response = await client.SendAsync(request);
+ var body = await response.Content.ReadAsStringAsync();
+
+ var authenticateHeader = string.Join(" | ", response.Headers.WwwAuthenticate.Select(h => h.ToString()));
+ Assert.True(response.StatusCode == HttpStatusCode.Unauthorized,
+ $"Expected 401 challenge, got {(int)response.StatusCode} {response.StatusCode}. Headers: {authenticateHeader}. Body: {body}");
+
+ Assert.Contains(response.Headers.WwwAuthenticate, header =>
+ string.Equals(header.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase));
+ }
+}