From 629b0b680d26452b3bf51f8777fe2e2e027da827 Mon Sep 17 00:00:00 2001 From: Szymon Homa Date: Tue, 16 Mar 2021 22:57:18 +0100 Subject: [PATCH 1/2] Extract class MockRedirectHandler --- .../auth/external/MockRedirectHandler.java | 34 +++++++++++++++++++ .../external/TestExternalAuthentication.java | 18 ---------- 2 files changed, 34 insertions(+), 18 deletions(-) create mode 100644 client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java b/client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java new file mode 100644 index 0000000000000..7aa3bebc06bb8 --- /dev/null +++ b/client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.auth.external; + +import java.net.URI; + +public class MockRedirectHandler + implements RedirectHandler +{ + private URI redirectedTo; + + @Override + public void redirectTo(URI uri) + throws RedirectException + { + redirectedTo = uri; + } + + public URI redirectedTo() + { + return redirectedTo; + } +} diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java index 279d1966b4c7f..dce341d203923 100644 --- a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java +++ b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthentication.java @@ -125,22 +125,4 @@ public void testObtainTokenWhenNoRedirectUriHasBeenProvided() assertThat(redirectHandler.redirectedTo()).isNull(); assertThat(token).map(Token::token).hasValue(AUTH_TOKEN); } - - private static class MockRedirectHandler - implements RedirectHandler - { - private URI redirectedTo; - - @Override - public void redirectTo(URI uri) - throws RedirectException - { - redirectedTo = uri; - } - - public URI redirectedTo() - { - return redirectedTo; - } - } } From 7162450321cb30860dbd40177a92b0039878934f Mon Sep 17 00:00:00 2001 From: Szymon Homa Date: Tue, 16 Mar 2021 23:30:27 +0100 Subject: [PATCH 2/2] Add cached token option to jdbc externalAuthentication This change allows sharing external authentication tokens between different Connections. Each time when a new token is required, first Connection that needs it, will handle obtaining a new token when all the other Connections wait for this operation to finish. Token is kept in memmory, guarded by ReadWriteLock. To enable token cache use externalAuthenticationTokenCache=MEMORY Default value for externalAuthenticationTokenCache is NONE. --- .../main/java/io/trino/cli/QueryRunner.java | 2 + client/trino-client/pom.xml | 6 + .../auth/external/ExternalAuthenticator.java | 32 +-- .../client/auth/external/KnownToken.java | 34 +++ .../client/auth/external/LocalKnownToken.java | 46 ++++ .../auth/external/MemoryCachedKnownToken.java | 83 +++++++ .../auth/external/MockRedirectHandler.java | 25 +++ .../client/auth/external/MockTokenPoller.java | 12 +- .../external/TestExternalAuthenticator.java | 208 +++++++++++++++++- .../io/trino/jdbc/ConnectionProperties.java | 11 + .../java/io/trino/jdbc/KnownTokenCache.java | 36 +++ .../java/io/trino/jdbc/TrinoDriverUri.java | 5 +- 12 files changed, 475 insertions(+), 25 deletions(-) create mode 100644 client/trino-client/src/main/java/io/trino/client/auth/external/KnownToken.java create mode 100644 client/trino-client/src/main/java/io/trino/client/auth/external/LocalKnownToken.java create mode 100644 client/trino-client/src/main/java/io/trino/client/auth/external/MemoryCachedKnownToken.java create mode 100644 client/trino-jdbc/src/main/java/io/trino/jdbc/KnownTokenCache.java diff --git a/client/trino-cli/src/main/java/io/trino/cli/QueryRunner.java b/client/trino-cli/src/main/java/io/trino/cli/QueryRunner.java index 4de030dca5a7e..6e650b4e235d3 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/QueryRunner.java +++ b/client/trino-cli/src/main/java/io/trino/cli/QueryRunner.java @@ -20,6 +20,7 @@ import io.trino.client.StatementClient; import io.trino.client.auth.external.ExternalAuthenticator; import io.trino.client.auth.external.HttpTokenPoller; +import io.trino.client.auth.external.KnownToken; import io.trino.client.auth.external.RedirectHandler; import io.trino.client.auth.external.TokenPoller; import okhttp3.OkHttpClient; @@ -195,6 +196,7 @@ private static void setupExternalAuth( ExternalAuthenticator authenticator = new ExternalAuthenticator( redirectHandler, poller, + KnownToken.local(), Duration.ofMinutes(10)); builder.authenticator(authenticator); diff --git a/client/trino-client/pom.xml b/client/trino-client/pom.xml index 8c4e4de10e52f..9797fd37f5684 100644 --- a/client/trino-client/pom.xml +++ b/client/trino-client/pom.xml @@ -81,6 +81,12 @@ test + + io.airlift + concurrent + test + + com.squareup.okhttp3 mockwebserver diff --git a/client/trino-client/src/main/java/io/trino/client/auth/external/ExternalAuthenticator.java b/client/trino-client/src/main/java/io/trino/client/auth/external/ExternalAuthenticator.java index 2269cf87c2ace..e91ff6572dcd0 100644 --- a/client/trino-client/src/main/java/io/trino/client/auth/external/ExternalAuthenticator.java +++ b/client/trino-client/src/main/java/io/trino/client/auth/external/ExternalAuthenticator.java @@ -44,12 +44,13 @@ public class ExternalAuthenticator private final TokenPoller tokenPoller; private final RedirectHandler redirectHandler; private final Duration timeout; - private Token knownToken; + private final KnownToken knownToken; - public ExternalAuthenticator(RedirectHandler redirect, TokenPoller tokenPoller, Duration timeout) + public ExternalAuthenticator(RedirectHandler redirect, TokenPoller tokenPoller, KnownToken knownToken, Duration timeout) { this.tokenPoller = requireNonNull(tokenPoller, "tokenPoller is null"); this.redirectHandler = requireNonNull(redirect, "redirect is null"); + this.knownToken = requireNonNull(knownToken, "knownToken is null"); this.timeout = requireNonNull(timeout, "timeout is null"); } @@ -57,28 +58,27 @@ public ExternalAuthenticator(RedirectHandler redirect, TokenPoller tokenPoller, @Override public Request authenticate(Route route, Response response) { - knownToken = null; - - Optional authentication = toAuthentication(response); - if (!authentication.isPresent()) { - return null; - } + knownToken.setupToken(() -> { + Optional authentication = toAuthentication(response); + if (!authentication.isPresent()) { + return Optional.empty(); + } - Optional token = authentication.get().obtainToken(timeout, redirectHandler, tokenPoller); - if (!token.isPresent()) { - return null; - } + return authentication.get().obtainToken(timeout, redirectHandler, tokenPoller); + }); - knownToken = token.get(); - return withBearerToken(response.request(), knownToken); + return knownToken.getToken() + .map(token -> withBearerToken(response.request(), token)) + .orElse(null); } @Override public Response intercept(Chain chain) throws IOException { - if (knownToken != null) { - return chain.proceed(withBearerToken(chain.request(), knownToken)); + Optional token = knownToken.getToken(); + if (token.isPresent()) { + return chain.proceed(withBearerToken(chain.request(), token.get())); } return chain.proceed(chain.request()); diff --git a/client/trino-client/src/main/java/io/trino/client/auth/external/KnownToken.java b/client/trino-client/src/main/java/io/trino/client/auth/external/KnownToken.java new file mode 100644 index 0000000000000..240af90757639 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/auth/external/KnownToken.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.auth.external; + +import java.util.Optional; +import java.util.function.Supplier; + +public interface KnownToken +{ + Optional getToken(); + + void setupToken(Supplier> tokenSource); + + static KnownToken local() + { + return new LocalKnownToken(); + } + + static KnownToken memoryCached() + { + return MemoryCachedKnownToken.INSTANCE; + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/auth/external/LocalKnownToken.java b/client/trino-client/src/main/java/io/trino/client/auth/external/LocalKnownToken.java new file mode 100644 index 0000000000000..40b984a0764bd --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/auth/external/LocalKnownToken.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.auth.external; + +import javax.annotation.concurrent.NotThreadSafe; + +import java.util.Optional; +import java.util.function.Supplier; + +import static java.util.Objects.requireNonNull; + +/** + * LocalKnownToken class keeps the token on its field + * and it's designed to use it in fully serialized manner. + */ +@NotThreadSafe +class LocalKnownToken + implements KnownToken +{ + private Optional knownToken = Optional.empty(); + + @Override + public Optional getToken() + { + return knownToken; + } + + @Override + public void setupToken(Supplier> tokenSource) + { + requireNonNull(tokenSource, "tokenSource is null"); + + knownToken = tokenSource.get(); + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/auth/external/MemoryCachedKnownToken.java b/client/trino-client/src/main/java/io/trino/client/auth/external/MemoryCachedKnownToken.java new file mode 100644 index 0000000000000..e8513e4bd87fc --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/auth/external/MemoryCachedKnownToken.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.auth.external; + +import javax.annotation.concurrent.ThreadSafe; + +import java.util.Optional; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Supplier; + +/** + * This KnownToken instance forces all Connections to reuse same token. + * Every time an existing token is considered to be invalid each Connection + * will try to obtain a new token, but only the first one will actually do the job, + * where every other connection will be waiting on readLock + * until obtaining new token finishes. + *

+ * In general the game is to reuse same token and obtain it only once, no matter how + * many Connections will be actively using it. It's very important as obtaining the new token + * will take minutes, as it mostly requires user thinking time. + */ +@ThreadSafe +class MemoryCachedKnownToken + implements KnownToken +{ + public static final MemoryCachedKnownToken INSTANCE = new MemoryCachedKnownToken(); + + private final ReadWriteLock lock = new ReentrantReadWriteLock(); + private final Lock readLock = lock.readLock(); + private final Lock writeLock = lock.writeLock(); + private Optional knownToken = Optional.empty(); + + private MemoryCachedKnownToken() + { + } + + @Override + public Optional getToken() + { + try { + readLock.lockInterruptibly(); + return knownToken; + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + finally { + readLock.unlock(); + } + } + + @Override + public void setupToken(Supplier> tokenSource) + { + // Try to lock and generate new token. If some other thread (Connection) has + // already obtained writeLock and is generating new token, then skipp this + // to block on getToken() + if (writeLock.tryLock()) { + try { + // Clear knownToken before obtaining new token, as it might fail leaving old invalid token. + knownToken = Optional.empty(); + knownToken = tokenSource.get(); + } + finally { + writeLock.unlock(); + } + } + } +} diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java b/client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java index 7aa3bebc06bb8..cf671ea391c9b 100644 --- a/client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java +++ b/client/trino-client/src/test/java/io/trino/client/auth/external/MockRedirectHandler.java @@ -14,21 +14,46 @@ package io.trino.client.auth.external; import java.net.URI; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; public class MockRedirectHandler implements RedirectHandler { private URI redirectedTo; + private AtomicInteger redirectionCount = new AtomicInteger(0); + private Duration redirectTime; @Override public void redirectTo(URI uri) throws RedirectException { redirectedTo = uri; + redirectionCount.incrementAndGet(); + try { + if (redirectTime != null) { + Thread.sleep(redirectTime.toMillis()); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } } public URI redirectedTo() { return redirectedTo; } + + public int getRedirectionCount() + { + return redirectionCount.get(); + } + + public MockRedirectHandler sleepOnRedirect(Duration redirectTime) + { + this.redirectTime = redirectTime; + return this; + } } diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/MockTokenPoller.java b/client/trino-client/src/test/java/io/trino/client/auth/external/MockTokenPoller.java index fae34ea9c031a..d07205f38636c 100644 --- a/client/trino-client/src/test/java/io/trino/client/auth/external/MockTokenPoller.java +++ b/client/trino-client/src/test/java/io/trino/client/auth/external/MockTokenPoller.java @@ -17,21 +17,21 @@ import java.net.URI; import java.time.Duration; -import java.util.ArrayDeque; -import java.util.HashMap; import java.util.Map; -import java.util.Queue; +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingDeque; public final class MockTokenPoller implements TokenPoller { - private final Map> results = new HashMap<>(); + private final Map> results = new ConcurrentHashMap<>(); public MockTokenPoller withResult(URI tokenUri, TokenPollResult result) { results.compute(tokenUri, (uri, queue) -> { if (queue == null) { - return new ArrayDeque<>(ImmutableList.of(result)); + return new LinkedBlockingDeque<>(ImmutableList.of(result)); } queue.add(result); return queue; @@ -42,7 +42,7 @@ public MockTokenPoller withResult(URI tokenUri, TokenPollResult result) @Override public TokenPollResult pollForToken(URI tokenUri, Duration ignored) { - Queue queue = results.get(tokenUri); + BlockingDeque queue = results.get(tokenUri); if (queue == null) { throw new IllegalArgumentException("Unknown token URI: " + tokenUri); } diff --git a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java index 729ffe6fa0230..c9ea6fc73199e 100644 --- a/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java +++ b/client/trino-client/src/test/java/io/trino/client/auth/external/TestExternalAuthenticator.java @@ -13,31 +13,55 @@ */ package io.trino.client.auth.external; +import com.google.common.collect.ImmutableList; import io.trino.client.ClientException; import okhttp3.HttpUrl; import okhttp3.Protocol; import okhttp3.Request; import okhttp3.Response; +import org.assertj.core.api.ListAssert; +import org.assertj.core.api.ThrowableAssert; +import org.testng.annotations.AfterClass; import org.testng.annotations.Test; import java.net.URI; import java.net.URISyntaxException; import java.time.Duration; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.stream.Stream; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.net.HttpHeaders.AUTHORIZATION; import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.client.auth.external.ExternalAuthenticator.TOKEN_URI_FIELD; import static io.trino.client.auth.external.ExternalAuthenticator.toAuthentication; import static io.trino.client.auth.external.TokenPollResult.successful; import static java.lang.String.format; import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; import static java.net.URI.create; +import static java.util.concurrent.Executors.newCachedThreadPool; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestExternalAuthenticator { + private static final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(TestExternalAuthenticator.class.getName() + "-%d")); + + @AfterClass(alwaysRun = true) + public void shutDownThreadPool() + { + executor.shutdownNow(); + } + @Test public void testChallengeWithOnlyTokenServerUri() { @@ -110,7 +134,7 @@ public void testAuthentication() { MockTokenPoller tokenPoller = new MockTokenPoller() .withResult(URI.create("http://token.uri"), successful(new Token("valid-token"))); - ExternalAuthenticator authenticator = new ExternalAuthenticator(uri -> {}, tokenPoller, Duration.ofSeconds(1)); + ExternalAuthenticator authenticator = new ExternalAuthenticator(uri -> {}, tokenPoller, KnownToken.local(), Duration.ofSeconds(1)); Request authenticated = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\"")); @@ -125,7 +149,7 @@ public void testReAuthenticationAfterRejectingToken() MockTokenPoller tokenPoller = new MockTokenPoller() .withResult(URI.create("http://token.uri"), successful(new Token("first-token"))) .withResult(URI.create("http://token.uri"), successful(new Token("second-token"))); - ExternalAuthenticator authenticator = new ExternalAuthenticator(uri -> {}, tokenPoller, Duration.ofSeconds(1)); + ExternalAuthenticator authenticator = new ExternalAuthenticator(uri -> {}, tokenPoller, KnownToken.local(), Duration.ofSeconds(1)); Request request = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\"")); Request reAuthenticated = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\"", request)); @@ -134,6 +158,140 @@ public void testReAuthenticationAfterRejectingToken() .containsExactly("Bearer second-token"); } + @Test(timeOut = 2000) + public void testAuthenticationFromMultipleThreadsWithLocallyStoredToken() + { + MockTokenPoller tokenPoller = new MockTokenPoller() + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-1"))) + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-2"))) + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-3"))) + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-4"))); + MockRedirectHandler redirectHandler = new MockRedirectHandler(); + + ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, tokenPoller, KnownToken.local(), Duration.ofSeconds(1)); + List> requests = times( + 4, + () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""))) + .map(executor::submit) + .collect(toImmutableList()); + + ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests); + assertion.requests() + .extracting(Request::headers) + .extracting(headers -> headers.get(AUTHORIZATION)) + .contains("Bearer valid-token-1", "Bearer valid-token-2", "Bearer valid-token-3", "Bearer valid-token-4"); + assertion.assertThatNoExceptionsHasBeenThrown(); + assertThat(redirectHandler.getRedirectionCount()).isEqualTo(4); + } + + @Test(timeOut = 2000) + public void testAuthenticationFromMultipleThreadsWithCachedToken() + { + ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(this.getClass().getName() + "%n")); + MockTokenPoller tokenPoller = new MockTokenPoller() + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token"))); + MockRedirectHandler redirectHandler = new MockRedirectHandler() + .sleepOnRedirect(Duration.ofMillis(10)); + + ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, tokenPoller, KnownToken.memoryCached(), Duration.ofSeconds(1)); + List> requests = times( + 4, + () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""))) + .map(executor::submit) + .collect(toImmutableList()); + + ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests); + assertion.requests() + .extracting(Request::headers) + .extracting(headers -> headers.get(AUTHORIZATION)) + .containsOnly("Bearer valid-token"); + assertion.assertThatNoExceptionsHasBeenThrown(); + assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1); + } + + @Test(timeOut = 2000) + public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateFails() + { + MockTokenPoller tokenPoller = new MockTokenPoller() + .withResult(URI.create("http://token.uri"), TokenPollResult.successful(new Token("first-token"))) + .withResult(URI.create("http://token.uri"), TokenPollResult.failed("external authentication error")); + MockRedirectHandler redirectHandler = new MockRedirectHandler() + .sleepOnRedirect(Duration.ofMillis(10)); + + ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, tokenPoller, KnownToken.memoryCached(), Duration.ofSeconds(1)); + Request firstRequest = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\"")); + + List> requests = times( + 4, + () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\"", firstRequest))) + .map(executor::submit) + .collect(toImmutableList()); + + ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests); + assertion.requests().containsExactly(null, null, null); + assertion.firstException().hasMessage("external authentication error") + .isInstanceOf(ClientException.class); + + assertThat(redirectHandler.getRedirectionCount()).isEqualTo(2); + } + + @Test(timeOut = 2000) + public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateTimesOut() + { + MockRedirectHandler redirectHandler = new MockRedirectHandler() + .sleepOnRedirect(Duration.ofMillis(5)); + + ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, (uri, duration) -> TokenPollResult.pending(uri), KnownToken.memoryCached(), Duration.ofMillis(1)); + List> requests = times( + 4, + () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""))) + .map(executor::submit) + .collect(toImmutableList()); + + ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests); + assertion.requests() + .containsExactly(null, null, null, null); + assertion.assertThatNoExceptionsHasBeenThrown(); + assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1); + } + + @Test(timeOut = 2000) + public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateIsInterrupted() + throws Exception + { + ExecutorService interruptableThreadPool = newCachedThreadPool(daemonThreadsNamed(this.getClass().getName() + "-interruptable-%d")); + MockRedirectHandler redirectHandler = new MockRedirectHandler() + .sleepOnRedirect(Duration.ofMinutes(1)); + + ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, (uri, duration) -> TokenPollResult.pending(uri), KnownToken.memoryCached(), Duration.ofMillis(1)); + Future interruptedAuthentication = interruptableThreadPool.submit( + () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""))); + Thread.sleep(100); //It's here to make sure that authentication will start before the other threads. + List> requests = times( + 2, + () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""))) + .map(executor::submit) + .collect(toImmutableList()); + + Thread.sleep(100); + interruptableThreadPool.shutdownNow(); + + ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(ImmutableList.>builder() + .addAll(requests) + .add(interruptedAuthentication) + .build()); + assertion.requests().containsExactly(null, null); + assertion.firstException().hasRootCauseInstanceOf(InterruptedException.class); + + assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1); + } + + private static Stream> times(int times, Callable request) + { + return Stream.generate(() -> request) + .limit(times); + } + private static Optional buildAuthentication(String challengeHeader) { return toAuthentication(getUnauthorizedResponse(challengeHeader)); @@ -157,4 +315,50 @@ private static Response getUnauthorizedResponse(String challengeHeader, Request .header(WWW_AUTHENTICATE, challengeHeader) .build(); } + + static class ConcurrentRequestAssertion + { + private final List exceptions = new ArrayList<>(); + private final List requests = new ArrayList<>(); + + public ConcurrentRequestAssertion(List> requests) + { + for (Future request : requests) { + try { + this.requests.add(request.get()); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + catch (CancellationException ex) { + exceptions.add(ex); + } + catch (ExecutionException ex) { + checkState(ex.getCause() != null, "Missing cause on ExecutionException " + ex.getMessage()); + + exceptions.add(ex.getCause()); + } + } + } + + ThrowableAssert firstException() + { + return exceptions.stream() + .findFirst() + .map(ThrowableAssert::new) + .orElseGet(() -> new ThrowableAssert(() -> null)); + } + + void assertThatNoExceptionsHasBeenThrown() + { + assertThat(exceptions) + .isEmpty(); + } + + ListAssert requests() + { + return assertThat(requests); + } + } } diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/ConnectionProperties.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/ConnectionProperties.java index c3db9662374b3..2091719fa0d91 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/ConnectionProperties.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/ConnectionProperties.java @@ -73,6 +73,7 @@ enum SslVerificationMode public static final ConnectionProperty ACCESS_TOKEN = new AccessToken(); public static final ConnectionProperty EXTERNAL_AUTHENTICATION = new ExternalAuthentication(); public static final ConnectionProperty EXTERNAL_AUTHENTICATION_TIMEOUT = new ExternalAuthenticationTimeout(); + public static final ConnectionProperty EXTERNAL_AUTHENTICATION_TOKEN_CACHE = new ExternalAuthenticationTokenCache(); public static final ConnectionProperty> EXTRA_CREDENTIALS = new ExtraCredentials(); public static final ConnectionProperty CLIENT_INFO = new ClientInfo(); public static final ConnectionProperty CLIENT_TAGS = new ClientTags(); @@ -113,6 +114,7 @@ enum SslVerificationMode .add(SOURCE) .add(EXTERNAL_AUTHENTICATION) .add(EXTERNAL_AUTHENTICATION_TIMEOUT) + .add(EXTERNAL_AUTHENTICATION_TOKEN_CACHE) .build(); private static final Map> KEY_LOOKUP = unmodifiableMap(ALL_PROPERTIES.stream() @@ -461,6 +463,15 @@ public ExternalAuthenticationTimeout() } } + private static class ExternalAuthenticationTokenCache + extends AbstractConnectionProperty + { + public ExternalAuthenticationTokenCache() + { + super("externalAuthenticationTokenCache", Optional.of(KnownTokenCache.NONE.name()), NOT_REQUIRED, ALLOWED, KnownTokenCache::valueOf); + } + } + private static class ExtraCredentials extends AbstractConnectionProperty> { diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/KnownTokenCache.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/KnownTokenCache.java new file mode 100644 index 0000000000000..6c3dde57d8c73 --- /dev/null +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/KnownTokenCache.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.jdbc; + +import io.trino.client.auth.external.KnownToken; + +public enum KnownTokenCache +{ + NONE { + @Override + KnownToken create() + { + return KnownToken.local(); + } + }, + MEMORY { + @Override + KnownToken create() + { + return KnownToken.memoryCached(); + } + }; + + abstract KnownToken create(); +} diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDriverUri.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDriverUri.java index 9a2af8cb55d06..3f7df397ece40 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDriverUri.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoDriverUri.java @@ -58,6 +58,7 @@ import static io.trino.jdbc.ConnectionProperties.DISABLE_COMPRESSION; import static io.trino.jdbc.ConnectionProperties.EXTERNAL_AUTHENTICATION; import static io.trino.jdbc.ConnectionProperties.EXTERNAL_AUTHENTICATION_TIMEOUT; +import static io.trino.jdbc.ConnectionProperties.EXTERNAL_AUTHENTICATION_TOKEN_CACHE; import static io.trino.jdbc.ConnectionProperties.EXTRA_CREDENTIALS; import static io.trino.jdbc.ConnectionProperties.HTTP_PROXY; import static io.trino.jdbc.ConnectionProperties.KERBEROS_CONFIG_PATH; @@ -310,7 +311,9 @@ public void setupClient(OkHttpClient.Builder builder) .map(value -> Duration.ofMillis(value.toMillis())) .orElse(Duration.ofMinutes(2)); - ExternalAuthenticator authenticator = new ExternalAuthenticator(REDIRECT_HANDLER.get(), poller, timeout); + KnownTokenCache knownTokenCache = EXTERNAL_AUTHENTICATION_TOKEN_CACHE.getValue(properties).get(); + + ExternalAuthenticator authenticator = new ExternalAuthenticator(REDIRECT_HANDLER.get(), poller, knownTokenCache.create(), timeout); builder.authenticator(authenticator); builder.addInterceptor(authenticator);