Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.text.ParseException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;

Expand All @@ -44,6 +45,12 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
*/
private final AtomicReference<Mono<JWKSet>> cachedJWKSet = new AtomicReference<>(Mono.empty());

/**
* In-flight JWK set fetch request, used to coalesce concurrent fetches into a single
* HTTP call.
*/
private final AtomicReference<@Nullable Mono<JWKSet>> inflightRequest = new AtomicReference<>();

/**
* The cached JWK set URL.
*/
Expand Down Expand Up @@ -101,24 +108,23 @@ private Mono<List<JWK>> get(JWKSelector jwkSelector, JWKSet jwkSet) {
}

/**
* Updates the cached JWK set from the configured URL.
* Updates the cached JWK set from the configured URL. Concurrent calls are coalesced
* into a single HTTP request to prevent thundering herd during cold start.
* @return The updated JWK set.
* @throws RemoteKeySourceException If JWK retrieval failed.
*/
private Mono<JWKSet> getJWKSet() {
// @formatter:off
return this.jwkSetUrlProvider
.flatMap((jwkSetURL) -> this.webClient.get()
.uri(jwkSetURL)
.retrieve()
.bodyToMono(String.class)
)
.map(this::parse)
.doOnNext((jwkSet) -> this.cachedJWKSet
.set(Mono.just(jwkSet))
)
.cache();
// @formatter:on
Mono<JWKSet> fetch = Mono.defer(() -> this.jwkSetUrlProvider
.flatMap((jwkSetURL) -> this.webClient.get().uri(jwkSetURL).retrieve().bodyToMono(String.class))
.map(this::parse)
.doOnNext((jwkSet) -> {
this.cachedJWKSet.set(Mono.just(jwkSet));
this.inflightRequest.set(null);
})
.doOnError((ex) -> this.inflightRequest.set(null))
.doOnCancel(() -> this.inflightRequest.set(null))
.switchIfEmpty(Mono.fromRunnable(() -> this.inflightRequest.set(null)))).cache();
return Objects.requireNonNull(this.inflightRequest.updateAndGet((v) -> (v != null) ? v : fetch));
}

private JWKSet parse(String body) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

package org.springframework.security.oauth2.jwt;

import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import com.nimbusds.jose.jwk.JWK;
Expand All @@ -32,7 +34,9 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import org.springframework.web.reactive.function.client.WebClientResponseException;

Expand Down Expand Up @@ -166,6 +170,44 @@ public void getWhenNoMatchAndKeyIdMatchThenEmpty() {
assertThat(this.source.get(this.selector).block()).isEmpty();
}

@Test
public void getWhenConcurrentRequestsThenSingleFetch() {
// given
given(this.matcher.matches(any())).willReturn(true);
int concurrentRequests = 10;
for (int i = 0; i < concurrentRequests; i++) {
this.server.enqueue(new MockResponse().setBody(this.keys).setBodyDelay(100, TimeUnit.MILLISECONDS));
}

// when
List<List<JWK>> results = Flux.range(0, concurrentRequests)
.flatMap((i) -> this.source.get(this.selector).subscribeOn(Schedulers.parallel()), concurrentRequests)
.collectList()
.block(Duration.ofSeconds(5));

// then
assertThat(results).hasSize(concurrentRequests);
assertThat(this.server.getRequestCount()).isEqualTo(1);
}

@Test
public void getWhenEmptyResponseThenNextCallSucceeds() {
// given
given(this.matcher.matches(any())).willReturn(true);
this.source = new ReactiveRemoteJWKSource(Mono.fromSupplier(this.mockStringSupplier));
// first call: supplier returns null URL, causing empty Mono from
// jwkSetUrlProvider
willReturn(null).given(this.mockStringSupplier).get();

// when: first call completes empty
List<JWK> firstResult = this.source.get(this.selector).block(Duration.ofSeconds(5));

// then: inflight is cleared and second call can succeed
willReturn(this.server.url("/").toString()).given(this.mockStringSupplier).get();
List<JWK> secondResult = this.source.get(this.selector).block(Duration.ofSeconds(5));
assertThat(secondResult).isNotEmpty();
}

@Test
public void getShouldRecoverAndReturnKeysAfterErrorCase() {
given(this.matcher.matches(any())).willReturn(true);
Expand Down