Skip to content

Commit

Permalink
Optional support for connection / read timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
skjolber committed Jun 5, 2018
1 parent aa771ad commit e7fbffa
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 1 deletion.
27 changes: 26 additions & 1 deletion src/main/java/com/auth0/jwk/UrlJwkProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.util.List;
import java.util.Map;

Expand All @@ -27,14 +28,31 @@ public class UrlJwkProvider implements JwkProvider {
static final String WELL_KNOWN_JWKS_PATH = "/.well-known/jwks.json";

final URL url;
private final Integer connectTimeout;
private final Integer readTimeout;

/**
* Creates a provider that loads from the given URL
* @param url to load the jwks
*/
public UrlJwkProvider(URL url) {
this(url, null, null);
}

/**
* Creates a provider that loads from the given URL
* @param url to load the jwks
* @param connectTimeout connection timeout in milliseconds (null for default)
* @param readTimeout read timeout in milliseconds (null for default)
*/
public UrlJwkProvider(URL url, Integer connectTimeout, Integer readTimeout) {
checkArgument(url != null, "A non-null url is required");
checkArgument(connectTimeout == null || connectTimeout >= 0, "Invalid connect timeout value '" + connectTimeout + "'. Must be a non-negative integer.");
checkArgument(readTimeout == null || readTimeout >= 0, "Invalid read timeout value '" + readTimeout + "'. Must be a non-negative integer.");

this.url = url;
this.connectTimeout = connectTimeout;
this.readTimeout = readTimeout;
}

/**
Expand Down Expand Up @@ -68,7 +86,14 @@ static URL urlForDomain(String domain) {

private Map<String, Object> getJwks() throws SigningKeyNotFoundException {
try {
final InputStream inputStream = this.url.openStream();
final URLConnection c = this.url.openConnection();
if(connectTimeout != null) {
c.setConnectTimeout(connectTimeout);
}
if(readTimeout != null) {
c.setReadTimeout(readTimeout);
}
final InputStream inputStream = c.getInputStream();
final JsonFactory factory = new JsonFactory();
final JsonParser parser = factory.createParser(inputStream);
final TypeReference<Map<String, Object>> typeReference = new TypeReference<Map<String, Object>>() {};
Expand Down
72 changes: 72 additions & 0 deletions src/test/java/com/auth0/jwk/UrlJwkProviderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,24 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;

import java.io.IOException;
import java.lang.ref.WeakReference;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.net.URLStreamHandlerFactory;

import static com.auth0.jwk.UrlJwkProvider.WELL_KNOWN_JWKS_PATH;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.*;
import static org.hamcrest.Matchers.*;
import static org.mockito.Mockito.mock;

public class UrlJwkProviderTest {

Expand Down Expand Up @@ -128,4 +139,65 @@ public void shouldFailOnInvalidProtocol() {
String domainWithInvalidProtocol = "httptest://samples.auth0.com";
new UrlJwkProvider(domainWithInvalidProtocol);
}

@Test
public void shouldFailWithNegativeConnectTimeout() throws MalformedURLException {
expectedException.expect(IllegalArgumentException.class);
new UrlJwkProvider(new URL("https://localhost"), -1, null);
}

@Test
public void shouldFailWithNegativeReadTimeout() throws MalformedURLException {
expectedException.expect(IllegalArgumentException.class);
new UrlJwkProvider(new URL("https://localhost"), null, -1);
}

private static class MockURLStreamHandlerFactory implements URLStreamHandlerFactory {

// The weak reference is just a safeguard against objects not being released
// for garbage collection
private final WeakReference<URLConnection> value;

public MockURLStreamHandlerFactory(URLConnection urlConnection) {
this.value = new WeakReference<URLConnection>(urlConnection);
}

@Override
public URLStreamHandler createURLStreamHandler(String protocol) {
return "mock".equals(protocol) ? new URLStreamHandler() {
protected URLConnection openConnection(URL url) throws IOException {
try {
return value.get();
} finally {
value.clear();
}
}
} : null;
}
}

@Test
public void shouldConfigureURLConnectionTimeouts() throws Exception {
URLConnection urlConnection = mock(URLConnection.class);

// Although somewhat of a hack, this approach gets the job done - this method can
// only be called once per virtual machine, but that is sufficient for now.
URL.setURLStreamHandlerFactory(new MockURLStreamHandlerFactory(urlConnection));
when(urlConnection.getInputStream()).thenReturn(getClass().getResourceAsStream("/jwks.json"));

int connectTimeout = 10000;
int readTimeout = 15000;

UrlJwkProvider urlJwkProvider = new UrlJwkProvider(new URL("mock://localhost"), connectTimeout, readTimeout);
Jwk jwk = urlJwkProvider.get("NkJCQzIyQzRBMEU4NjhGNUU4MzU4RkY0M0ZDQzkwOUQ0Q0VGNUMwQg");
assertNotNull(jwk);

ArgumentCaptor<Integer> connectTimeoutCaptor = ArgumentCaptor.forClass(Integer.class);
verify(urlConnection).setConnectTimeout(connectTimeoutCaptor.capture());
assertThat(connectTimeoutCaptor.getValue(),is(connectTimeout));

ArgumentCaptor<Integer> readTimeoutCaptor = ArgumentCaptor.forClass(Integer.class);
verify(urlConnection).setReadTimeout(readTimeoutCaptor.capture());
assertThat(readTimeoutCaptor.getValue(),is(readTimeout));
}
}

0 comments on commit e7fbffa

Please sign in to comment.