Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions EXAMPLES.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ JwkProvider provider = new JwkProviderBuilder("https://samples.auth0.com/")
.build();
```

### Configure SSL/TLS settings

A custom `SSLSocketFactory` can be configured for HTTPS connections to the JWKS endpoint. This is useful for environments that require a specific TLS version, custom trust stores, or mutual TLS (mTLS).

```java
// Configure a specific TLS version
SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
sslContext.init(null, null, new SecureRandom());

JwkProvider provider = new JwkProviderBuilder("https://samples.auth0.com/")
.sslSocketFactory(sslContext.getSocketFactory())
.build();
```

When not configured, the JVM's default SSL settings will be used.

See the [JwkProviderBuilder JavaDocs](https://javadoc.io/doc/com.auth0/jwks-rsa/latest/com/auth0/jwk/JwkProviderBuilder.html) for all available configurations.

## Error handling
Expand Down
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ ext {
}

dependencies {
implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version:'2.18.6'
implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version:'2.21.2'
implementation (group: 'com.google.guava', name: 'guava', version:'32.1.2-jre') {
// needed due to https://github.com/google/guava/issues/6654
exclude group: "org.mockito", module: "mockito-core"
Expand Down
17 changes: 16 additions & 1 deletion src/main/java/com/auth0/jwk/JwkProviderBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLSocketFactory;

import static com.auth0.jwk.UrlJwkProvider.urlForDomain;

Expand All @@ -24,6 +25,7 @@ public class JwkProviderBuilder {
private BucketImpl bucket;
private boolean rateLimited;
private Map<String, String> headers;
private SSLSocketFactory sslSocketFactory;

/**
* Creates a new Builder with the given URL where to load the jwks from.
Expand Down Expand Up @@ -166,13 +168,26 @@ public JwkProviderBuilder headers(Map<String, String> headers) {
return this;
}

/**
* Sets a custom {@link SSLSocketFactory} for HTTPS connections to the JWKS endpoint.
* This allows configuration of TLS version, cipher suites, and custom certificate validation.
* When not set, the JVM default SSL configuration will be used.
*
* @param sslSocketFactory the SSL socket factory to use for HTTPS connections (null for JVM default)
* @return the builder
*/
public JwkProviderBuilder sslSocketFactory(SSLSocketFactory sslSocketFactory) {
this.sslSocketFactory = sslSocketFactory;
return this;
}

/**
* Creates a {@link JwkProvider}
*
* @return a newly created {@link JwkProvider}
*/
public JwkProvider build() {
JwkProvider urlProvider = new UrlJwkProvider(url, connectTimeout, readTimeout, proxy, headers);
JwkProvider urlProvider = new UrlJwkProvider(url, connectTimeout, readTimeout, proxy, headers, sslSocketFactory);
if (this.rateLimited) {
urlProvider = new RateLimitedJwkProvider(urlProvider, bucket);
}
Expand Down
21 changes: 21 additions & 0 deletions src/main/java/com/auth0/jwk/UrlJwkProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import java.net.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLSocketFactory;

/**
* Jwk provider that loads them from a {@link URL}
Expand All @@ -25,6 +27,7 @@ public class UrlJwkProvider implements JwkProvider {
final Map<String, String> headers;
final Integer connectTimeout;
final Integer readTimeout;
final SSLSocketFactory sslSocketFactory;

private final ObjectReader reader;

Expand Down Expand Up @@ -59,6 +62,20 @@ public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Prox
* @param headers a map of request header keys to values to send on the request. Default is "Accept: application/json".
*/
public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Proxy proxy, Map<String, String> headers) {
this(url, connectTimeout, readTimeout, proxy, headers, null);
}

/**
* Creates a provider that loads from the given URL using custom request headers and SSL configuration.
*
* @param url to load the jwks
* @param connectTimeout connection timeout in milliseconds (default is null)
* @param readTimeout read timeout in milliseconds (default is null)
* @param proxy proxy server to use when making the connection (default is null)
* @param headers a map of request header keys to values to send on the request. Default is "Accept: application/json".
* @param sslSocketFactory the SSL socket factory to use for HTTPS connections (null for JVM default)
*/
public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Proxy proxy, Map<String, String> headers, SSLSocketFactory sslSocketFactory) {
Util.checkArgument(url != null, "A non-null url is required");
Util.checkArgument(connectTimeout == null || connectTimeout >= 0, "Invalid connect timeout value '" + connectTimeout + "'. Must be a non-negative integer.");
Util.checkArgument(readTimeout == null || readTimeout >= 0, "Invalid read timeout value '" + readTimeout + "'. Must be a non-negative integer.");
Expand All @@ -67,6 +84,7 @@ public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout, Prox
this.proxy = proxy;
this.connectTimeout = connectTimeout;
this.readTimeout = readTimeout;
this.sslSocketFactory = sslSocketFactory;
this.reader = new ObjectMapper().readerFor(Map.class);

this.headers = (headers == null) ?
Expand Down Expand Up @@ -126,6 +144,9 @@ static URL urlForDomain(String domain) {
private Map<String, Object> getJwks() throws SigningKeyNotFoundException {
try {
final URLConnection c = (proxy == null) ? this.url.openConnection() : this.url.openConnection(proxy);
if (c instanceof HttpsURLConnection && sslSocketFactory != null) {
((HttpsURLConnection) c).setSSLSocketFactory(sslSocketFactory);
}
if (connectTimeout != null) {
c.setConnectTimeout(connectTimeout);
}
Expand Down
28 changes: 28 additions & 0 deletions src/test/java/com/auth0/jwk/JwkProviderBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLSocketFactory;

import static com.auth0.jwk.UrlJwkProvider.WELL_KNOWN_JWKS_PATH;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;
import static org.mockito.Mockito.mock;

public class JwkProviderBuilderTest {

Expand Down Expand Up @@ -195,4 +197,30 @@ public void shouldCreateForUrlWithCustomHeaders() throws Exception {
UrlJwkProvider urlJwkProvider = (UrlJwkProvider) provider;
assertThat(urlJwkProvider.headers, equalTo(headers));
}

@Test
public void shouldCreateForUrlWithSSLSocketFactory() throws Exception {
URL url = new URL(normalizedDomain + WELL_KNOWN_JWKS_PATH);
SSLSocketFactory sslSocketFactory = mock(SSLSocketFactory.class);
JwkProvider provider = new JwkProviderBuilder(url)
.sslSocketFactory(sslSocketFactory)
.rateLimited(false)
.cached(false)
.build();
assertThat(provider, notNullValue());
UrlJwkProvider urlJwkProvider = (UrlJwkProvider) provider;
assertThat(urlJwkProvider.sslSocketFactory, equalTo(sslSocketFactory));
}

@Test
public void shouldDefaultSSLSocketFactoryToNull() throws Exception {
URL url = new URL(normalizedDomain + WELL_KNOWN_JWKS_PATH);
JwkProvider provider = new JwkProviderBuilder(url)
.rateLimited(false)
.cached(false)
.build();
assertThat(provider, notNullValue());
UrlJwkProvider urlJwkProvider = (UrlJwkProvider) provider;
assertThat(urlJwkProvider.sslSocketFactory, is(nullValue()));
}
}
35 changes: 35 additions & 0 deletions src/test/java/com/auth0/jwk/UrlJwkProviderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLSocketFactory;

import static com.auth0.jwk.UrlJwkProvider.WELL_KNOWN_JWKS_PATH;
import static org.hamcrest.Matchers.*;
Expand Down Expand Up @@ -404,4 +406,37 @@ public void shouldFetchIfCacheIsNull() throws Exception {
verify(provider, atLeastOnce()).getAll(); // Should definitely be called
}

@Test
public void shouldConfigureSSLSocketFactoryForHttpsConnection() throws Exception {
HttpsURLConnection httpsUrlConnection = mock(HttpsURLConnection.class);
when(httpsUrlConnection.getInputStream()).thenAnswer(new Answer<Object>() {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
return getClass().getResourceAsStream("/jwks.json");
}
});

SSLSocketFactory sslSocketFactory = mock(SSLSocketFactory.class);
URL url = getClass().getResource("/jwks.json");
UrlJwkProvider provider = new UrlJwkProvider(url, null, null, null, null, sslSocketFactory);

assertThat(provider.sslSocketFactory, is(sslSocketFactory));
}

@Test
public void shouldDefaultSSLSocketFactoryToNull() throws Exception {
URL url = getClass().getResource("/jwks.json");
UrlJwkProvider provider = new UrlJwkProvider(url);

assertThat(provider.sslSocketFactory, is(nullValue()));
}

@Test
public void shouldDefaultSSLSocketFactoryToNullWith5ArgConstructor() throws Exception {
URL url = getClass().getResource("/jwks.json");
UrlJwkProvider provider = new UrlJwkProvider(url, null, null, null, null);

assertThat(provider.sslSocketFactory, is(nullValue()));
}

}
Loading