Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 77 additions & 72 deletions MCPify/Hosting/McpOAuthAuthenticationMiddleware.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using MCPify.Core;
using MCPify.Core.Auth;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Net.Http.Headers;

namespace MCPify.Hosting;
Expand All @@ -10,9 +8,6 @@ public class McpOAuthAuthenticationMiddleware
{
private readonly RequestDelegate _next;

/// <summary>
/// Key for storing token validation result in HttpContext.Items for downstream use.
/// </summary>
public const string TokenValidationResultKey = "McpTokenValidationResult";

public McpOAuthAuthenticationMiddleware(RequestDelegate next)
Expand All @@ -22,57 +17,48 @@ public McpOAuthAuthenticationMiddleware(RequestDelegate next)

public async Task InvokeAsync(HttpContext context)
{
// Skip check for metadata endpoint and other non-MCP endpoints
var path = context.Request.Path;
if (path.StartsWithSegments("/.well-known") ||
path.StartsWithSegments("/swagger") ||
path.StartsWithSegments("/health") ||
path.StartsWithSegments("/connect") || // OpenIddict or Auth endpoints
path.StartsWithSegments("/auth")) // Callback paths
path.StartsWithSegments("/connect") ||
path.StartsWithSegments("/auth"))
{
await _next(context);
return;
}

// Check if OAuth is configured
var oauthStore = context.RequestServices.GetService<OAuthConfigurationStore>();
var options = context.RequestServices.GetService<McpifyOptions>();
var oauthStore = context.RequestServices.GetService<OAuthConfigurationStore>();

var oauthConfigurations = oauthStore?.GetConfigurations().ToList() ?? [];
var validationOptions = options?.TokenValidation;
var tokenValidationEnabled = validationOptions?.EnableJwtValidation == true;

if (oauthStore == null || !oauthStore.GetConfigurations().Any())
var challengeScopes = BuildChallengeScopes(oauthConfigurations, validationOptions);
var authRequired = oauthConfigurations.Count > 0 || tokenValidationEnabled;

if (!authRequired)
{
await _next(context);
return;
}

var accessor = context.RequestServices.GetService<IMcpContextAccessor>();
var resourceUrl = GetResourceUrl(context, options);
var accessor = context.RequestServices.GetService<IMcpContextAccessor>();

// Check for Authorization header
string? authorization = context.Request.Headers[HeaderNames.Authorization];
if (string.IsNullOrEmpty(authorization) || !authorization.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
{
// No token - return 401 challenge
await WriteChallengeResponse(context, oauthStore, resourceUrl, null, null);
return;
}

// Extract token
var token = authorization.Substring("Bearer ".Length).Trim();
if (string.IsNullOrEmpty(token))
if (!TryGetBearerToken(context, out var token))
{
await WriteChallengeResponse(context, oauthStore, resourceUrl, null, null);
await WriteChallengeResponse(context, resourceUrl, challengeScopes, null, null);
return;
}

// Set token on accessor for downstream use
if (accessor != null)
{
accessor.AccessToken = token;
}

// Perform token validation if enabled
var validationOptions = options?.TokenValidation;
if (validationOptions?.EnableJwtValidation == true)
if (tokenValidationEnabled && validationOptions != null)
{
var validator = context.RequestServices.GetService<IAccessTokenValidator>();
if (validator != null)
Expand All @@ -82,31 +68,24 @@ public async Task InvokeAsync(HttpContext context)
: null;

var validationResult = await validator.ValidateAsync(token, expectedAudience, context.RequestAborted);

// Store validation result for downstream use
context.Items[TokenValidationResultKey] = validationResult;

if (!validationResult.IsValid)
{
// Token is invalid (expired, malformed, wrong audience) - return 401
await WriteInvalidTokenResponse(context, oauthStore, resourceUrl,
await WriteChallengeResponse(context, resourceUrl, challengeScopes,
validationResult.ErrorCode ?? "invalid_token",
validationResult.ErrorDescription ?? "Token validation failed");
return;
}

// Validate scopes if enabled
if (validationOptions.ValidateScopes)
{
var scopeStore = context.RequestServices.GetService<ScopeRequirementStore>();
if (scopeStore != null)
{
// Use default validation (no specific tool name available at middleware level)
var scopeResult = scopeStore.ValidateScopesForTool("*", validationResult.Scopes);

if (!scopeResult.IsValid)
{
// Token is valid but lacks required scopes - return 403
await WriteInsufficientScopeResponse(context, resourceUrl, scopeResult.MissingScopes);
return;
}
Expand Down Expand Up @@ -134,50 +113,69 @@ private static string GetResourceUrl(HttpContext context, McpifyOptions? options
return resourceUrl.TrimEnd('/');
}

private static async Task WriteChallengeResponse(
HttpContext context,
OAuthConfigurationStore oauthStore,
string resourceUrl,
string? errorCode,
string? errorDescription)
private static IReadOnlyList<string> BuildChallengeScopes(
IReadOnlyCollection<OAuth2Configuration> configurations,
TokenValidationOptions? validationOptions)
{
var metadataUrl = $"{resourceUrl}/.well-known/oauth-protected-resource";
var defaultScopes = validationOptions?.DefaultRequiredScopes;
var hasDefaultScopes = defaultScopes is { Count: > 0 };

// Collect all scopes from OAuth configurations per MCP spec
var allScopes = oauthStore.GetConfigurations()
.SelectMany(c => c.Scopes.Keys)
.Distinct(StringComparer.OrdinalIgnoreCase)
.ToList();
if (configurations.Count == 0 && !hasDefaultScopes)
{
return Array.Empty<string>();
}

context.Response.StatusCode = StatusCodes.Status401Unauthorized;
var scopes = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

foreach (var configuration in configurations)
{
foreach (var scope in configuration.Scopes.Keys)
{
scopes.Add(scope);
}
}

// Build WWW-Authenticate header per MCP Authorization spec
var wwwAuthenticate = BuildWwwAuthenticateHeader(metadataUrl, allScopes, errorCode, errorDescription);
context.Response.Headers[HeaderNames.WWWAuthenticate] = wwwAuthenticate;
if (hasDefaultScopes && defaultScopes != null)
{
foreach (var scope in defaultScopes)
{
scopes.Add(scope);
}
}

return scopes.ToList();
}

private static async Task WriteInvalidTokenResponse(
HttpContext context,
OAuthConfigurationStore oauthStore,
string resourceUrl,
string errorCode,
string errorDescription)
private static bool TryGetBearerToken(HttpContext context, out string token)
{
var metadataUrl = $"{resourceUrl}/.well-known/oauth-protected-resource";
token = string.Empty;
string? authorization = context.Request.Headers[HeaderNames.Authorization];

// Collect all scopes from OAuth configurations
var allScopes = oauthStore.GetConfigurations()
.SelectMany(c => c.Scopes.Keys)
.Distinct(StringComparer.OrdinalIgnoreCase)
.ToList();
if (string.IsNullOrEmpty(authorization) ||
!authorization.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
{
return false;
}

context.Response.StatusCode = StatusCodes.Status401Unauthorized;
token = authorization.Substring("Bearer ".Length).Trim();
return !string.IsNullOrEmpty(token);
}

var wwwAuthenticate = BuildWwwAuthenticateHeader(metadataUrl, allScopes, errorCode, errorDescription);
context.Response.Headers[HeaderNames.WWWAuthenticate] = wwwAuthenticate;
private static Task WriteChallengeResponse(
HttpContext context,
string resourceUrl,
IReadOnlyList<string> scopes,
string? errorCode,
string? errorDescription)
{
context.Response.StatusCode = StatusCodes.Status401Unauthorized;
var metadataUrl = $"{resourceUrl}/.well-known/oauth-protected-resource";
context.Response.Headers[HeaderNames.WWWAuthenticate] =
BuildWwwAuthenticateHeader(metadataUrl, scopes, errorCode, errorDescription);
return Task.CompletedTask;
}

private static async Task WriteInsufficientScopeResponse(
private static Task WriteInsufficientScopeResponse(
HttpContext context,
string resourceUrl,
IReadOnlyList<string> requiredScopes)
Expand All @@ -186,7 +184,6 @@ private static async Task WriteInsufficientScopeResponse(

context.Response.StatusCode = StatusCodes.Status403Forbidden;

// Build WWW-Authenticate header for insufficient_scope per RFC 6750 Section 3.1
var parts = new List<string>
{
"Bearer",
Expand All @@ -201,6 +198,7 @@ private static async Task WriteInsufficientScopeResponse(
}

context.Response.Headers[HeaderNames.WWWAuthenticate] = string.Join(", ", parts);
return Task.CompletedTask;
}

private static string BuildWwwAuthenticateHeader(
Expand All @@ -209,7 +207,15 @@ private static string BuildWwwAuthenticateHeader(
string? errorCode,
string? errorDescription)
{
var parts = new List<string> { $"Bearer resource_metadata=\"{metadataUrl}\"" };
if (string.IsNullOrEmpty(errorCode) && string.IsNullOrEmpty(errorDescription) && scopes.Count == 0)
{
return $"Bearer resource_metadata=\"{metadataUrl}\"";
}

var parts = new List<string>(4)
{
$"Bearer resource_metadata=\"{metadataUrl}\""
};

if (!string.IsNullOrEmpty(errorCode))
{
Expand All @@ -218,7 +224,6 @@ private static string BuildWwwAuthenticateHeader(

if (!string.IsNullOrEmpty(errorDescription))
{
// Escape quotes in description
var escapedDescription = errorDescription.Replace("\"", "\\\"");
parts.Add($"error_description=\"{escapedDescription}\"");
}
Expand Down
6 changes: 6 additions & 0 deletions Sample/Extensions/DemoServiceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ public static IServiceCollection AddDemoMcpify(this IServiceCollection services,
options.Transport = transport;
options.ResourceUrlOverride = baseUrl;

options.TokenValidation = new TokenValidationOptions
{
EnableJwtValidation = true,
ValidateAudience = true
};

// Expose the local API (which is now the "Real" API)
options.LocalEndpoints = new()
{
Expand Down
18 changes: 9 additions & 9 deletions Sample/MCPify.Sample.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,28 @@
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net8.0'">
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="8.0.0" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.0" />
<PackageReference Include="Microsoft.IdentityModel.Protocols.OpenIdConnect" Version="7.5.0" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="8.0.21" />
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="8.0.23" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.23" />
<PackageReference Include="Microsoft.IdentityModel.Protocols.OpenIdConnect" Version="7.5.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="8.0.23" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net9.0'">
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="9.0.0" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="9.0.0" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="9.0.10" />
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="9.0.12" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="9.0.12" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="9.0.12" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net10.0'">
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="10.0.0" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="9.0.0" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="9.0.12" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="10.0.0" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="OpenIddict.AspNetCore" Version="7.2.0" />
<PackageReference Include="OpenIddict.EntityFrameworkCore" Version="7.2.0" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="7.0.0" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="7.3.2" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using System.Net;
using System.Net.Http.Headers;
using System.Text;
using MCPify.Core;
using MCPify.Core.Auth;
using MCPify.Hosting;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;

namespace MCPify.Tests.Integration;

public class OAuthChallengeTokenValidationTests
{
[Fact]
public async Task PostWithoutSession_ReturnsUnauthorizedChallenge_WhenTokenValidationEnabled()
{
using var host = await new HostBuilder()
.ConfigureWebHost(webBuilder =>
{
webBuilder
.UseTestServer()
.ConfigureServices(services =>
{
services.AddLogging();
services.AddRouting();
services.AddMcpify(options =>
{
options.Transport = McpTransportType.Http;
options.TokenValidation = new TokenValidationOptions
{
EnableJwtValidation = true,
ValidateAudience = true
};
});
})
.Configure(app =>
{
app.UseRouting();
app.UseMcpifyContext();
app.UseMcpifyOAuth();
app.UseEndpoints(endpoints =>
{
endpoints.MapMcpifyEndpoint();
});
});
})
.StartAsync();

var options = host.Services.GetRequiredService<McpifyOptions>();
Assert.True(options.TokenValidation?.EnableJwtValidation, "Token validation should be enabled");
var validationOptions = host.Services.GetRequiredService<TokenValidationOptions>();
Assert.True(validationOptions.EnableJwtValidation, "TokenValidationOptions from DI should have EnableJwtValidation true");

var client = host.GetTestClient();

using var request = new HttpRequestMessage(HttpMethod.Post, "/")
{
Content = new StringContent("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"ping\",\"params\":{}}", Encoding.UTF8, "application/json")
};
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));

var response = await client.SendAsync(request);
var body = await response.Content.ReadAsStringAsync();

var authenticateHeader = string.Join(" | ", response.Headers.WwwAuthenticate.Select(h => h.ToString()));
Assert.True(response.StatusCode == HttpStatusCode.Unauthorized,
$"Expected 401 challenge, got {(int)response.StatusCode} {response.StatusCode}. Headers: {authenticateHeader}. Body: {body}");

Assert.Contains(response.Headers.WwwAuthenticate, header =>
string.Equals(header.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase));
}
}