diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java index e976328623..48f491ccad 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java @@ -19,6 +19,7 @@ import java.text.ParseException; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; @@ -44,6 +45,12 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource { */ private final AtomicReference> cachedJWKSet = new AtomicReference<>(Mono.empty()); + /** + * In-flight JWK set fetch request, used to coalesce concurrent fetches into a single + * HTTP call. + */ + private final AtomicReference<@Nullable Mono> inflightRequest = new AtomicReference<>(); + /** * The cached JWK set URL. */ @@ -101,24 +108,23 @@ private Mono> get(JWKSelector jwkSelector, JWKSet jwkSet) { } /** - * Updates the cached JWK set from the configured URL. + * Updates the cached JWK set from the configured URL. Concurrent calls are coalesced + * into a single HTTP request to prevent thundering herd during cold start. * @return The updated JWK set. * @throws RemoteKeySourceException If JWK retrieval failed. */ private Mono getJWKSet() { - // @formatter:off - return this.jwkSetUrlProvider - .flatMap((jwkSetURL) -> this.webClient.get() - .uri(jwkSetURL) - .retrieve() - .bodyToMono(String.class) - ) - .map(this::parse) - .doOnNext((jwkSet) -> this.cachedJWKSet - .set(Mono.just(jwkSet)) - ) - .cache(); - // @formatter:on + Mono fetch = Mono.defer(() -> this.jwkSetUrlProvider + .flatMap((jwkSetURL) -> this.webClient.get().uri(jwkSetURL).retrieve().bodyToMono(String.class)) + .map(this::parse) + .doOnNext((jwkSet) -> { + this.cachedJWKSet.set(Mono.just(jwkSet)); + this.inflightRequest.set(null); + }) + .doOnError((ex) -> this.inflightRequest.set(null)) + .doOnCancel(() -> this.inflightRequest.set(null)) + .switchIfEmpty(Mono.fromRunnable(() -> this.inflightRequest.set(null)))).cache(); + return Objects.requireNonNull(this.inflightRequest.updateAndGet((v) -> (v != null) ? v : fetch)); } private JWKSet parse(String body) { diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java index e00b76047f..719f8ae900 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java @@ -16,8 +16,10 @@ package org.springframework.security.oauth2.jwt; +import java.time.Duration; import java.util.Collections; import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import com.nimbusds.jose.jwk.JWK; @@ -32,7 +34,9 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import org.springframework.web.reactive.function.client.WebClientResponseException; @@ -166,6 +170,44 @@ public void getWhenNoMatchAndKeyIdMatchThenEmpty() { assertThat(this.source.get(this.selector).block()).isEmpty(); } + @Test + public void getWhenConcurrentRequestsThenSingleFetch() { + // given + given(this.matcher.matches(any())).willReturn(true); + int concurrentRequests = 10; + for (int i = 0; i < concurrentRequests; i++) { + this.server.enqueue(new MockResponse().setBody(this.keys).setBodyDelay(100, TimeUnit.MILLISECONDS)); + } + + // when + List> results = Flux.range(0, concurrentRequests) + .flatMap((i) -> this.source.get(this.selector).subscribeOn(Schedulers.parallel()), concurrentRequests) + .collectList() + .block(Duration.ofSeconds(5)); + + // then + assertThat(results).hasSize(concurrentRequests); + assertThat(this.server.getRequestCount()).isEqualTo(1); + } + + @Test + public void getWhenEmptyResponseThenNextCallSucceeds() { + // given + given(this.matcher.matches(any())).willReturn(true); + this.source = new ReactiveRemoteJWKSource(Mono.fromSupplier(this.mockStringSupplier)); + // first call: supplier returns null URL, causing empty Mono from + // jwkSetUrlProvider + willReturn(null).given(this.mockStringSupplier).get(); + + // when: first call completes empty + List firstResult = this.source.get(this.selector).block(Duration.ofSeconds(5)); + + // then: inflight is cleared and second call can succeed + willReturn(this.server.url("/").toString()).given(this.mockStringSupplier).get(); + List secondResult = this.source.get(this.selector).block(Duration.ofSeconds(5)); + assertThat(secondResult).isNotEmpty(); + } + @Test public void getShouldRecoverAndReturnKeysAfterErrorCase() { given(this.matcher.matches(any())).willReturn(true);