From d6393eeb576b1cd6898a6ca1874bb9157ec210db Mon Sep 17 00:00:00 2001 From: SimbaGithub <48035983+SimbaGithub@users.noreply.github.com> Date: Mon, 14 Mar 2022 11:46:06 -0700 Subject: [PATCH] SNOW 500881 JWT expiration issue (#715) *SNOW-500881 JWT expiration fix Co-authored-by: sfc-gh-hchaturvedi Co-authored-by: Lorna Barber --- .../net/snowflake/client/core/HttpUtil.java | 74 +++++++++++- .../net/snowflake/client/core/Incident.java | 5 +- .../snowflake/client/core/SFBaseSession.java | 8 +- .../snowflake/client/core/SFLoginInput.java | 12 +- .../net/snowflake/client/core/SFSession.java | 29 ++++- .../snowflake/client/core/SessionUtil.java | 106 ++++++++++++++++-- .../core/SessionUtilExternalBrowser.java | 9 +- .../client/core/SessionUtilKeyPair.java | 15 ++- .../net/snowflake/client/core/StmtUtil.java | 18 ++- .../client/jdbc/ChunkDownloadContext.java | 14 +++ .../jdbc/DefaultResultStreamProvider.java | 3 + .../net/snowflake/client/jdbc/ErrorCode.java | 3 +- .../snowflake/client/jdbc/RestRequest.java | 61 +++++++++- .../client/jdbc/SnowflakeChunkDownloader.java | 16 ++- .../SnowflakeResultSetSerializableV1.java | 15 ++- .../client/jdbc/SnowflakeSQLException.java | 29 ++++- .../cloud/storage/SnowflakeGCSClient.java | 11 +- .../jdbc/telemetry/TelemetryClient.java | 14 ++- .../core/SessionUtilExternalBrowserTest.java | 8 +- .../client/core/SessionUtilLatestIT.java | 90 +++++++++++++++ .../client/core/SnowflakeMFACacheTest.java | 16 ++- .../client/jdbc/ConnectionLatestIT.java | 4 +- .../client/jdbc/MockConnectionTest.java | 14 +++ .../client/jdbc/RestRequestTest.java | 51 ++++++++- .../client/jdbc/SSOConnectionTest.java | 9 +- .../client/jdbc/ServiceNameTest.java | 8 +- 26 files changed, 596 insertions(+), 46 deletions(-) create mode 100644 src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java diff --git a/src/main/java/net/snowflake/client/core/HttpUtil.java b/src/main/java/net/snowflake/client/core/HttpUtil.java index 062774b99..4e302b359 100644 --- a/src/main/java/net/snowflake/client/core/HttpUtil.java +++ b/src/main/java/net/snowflake/client/core/HttpUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; @@ -471,6 +471,24 @@ public static RequestConfig getDefaultRequestConfigWithSocketTimeout( .build(); } + /** + * Return a request configuration inheriting from the default request configuration of the shared + * HttpClient with a different socket and connect timeout. + * + * @param requestSocketAndConnectTimeout - custom socket and connect timeout in milli-seconds + * @param withoutCookies - whether this request should ignore cookies or not + * @return RequestConfig object + */ + public static RequestConfig getDefaultRequestConfigWithSocketAndConnectTimeout( + int requestSocketAndConnectTimeout, boolean withoutCookies) { + final String cookieSpec = withoutCookies ? IGNORE_COOKIES : DEFAULT; + return RequestConfig.copy(DefaultRequestConfig) + .setSocketTimeout(requestSocketAndConnectTimeout) + .setConnectTimeout(requestSocketAndConnectTimeout) + .setCookieSpec(cookieSpec) + .build(); + } + /** * Return a request configuration inheriting from the default request configuration of the shared * HttpClient with the coopkie spec set to ignore. @@ -517,6 +535,9 @@ public static boolean isSocksProxyDisabled() { * * @param httpRequest HttpRequestBase * @param retryTimeout retry timeout + * @param authTimeout authenticator specific timeout + * @param socketTimeout socket timeout (in ms) + * @param retryCount retry count for the request * @param injectSocketTimeout injecting socket timeout * @param canceling canceling? * @param ocspAndProxyKey OCSP mode and proxy settings for httpclient @@ -527,6 +548,9 @@ public static boolean isSocksProxyDisabled() { static String executeRequestWithoutCookies( HttpRequestBase httpRequest, int retryTimeout, + int authTimeout, + int socketTimeout, + int retryCount, int injectSocketTimeout, AtomicBoolean canceling, HttpClientSettingsKey ocspAndProxyKey) @@ -534,6 +558,9 @@ static String executeRequestWithoutCookies( return executeRequestInternal( httpRequest, retryTimeout, + authTimeout, + socketTimeout, + retryCount, injectSocketTimeout, canceling, true, // no cookie @@ -548,17 +575,28 @@ static String executeRequestWithoutCookies( * * @param httpRequest HttpRequestBase * @param retryTimeout retry timeout + * @param authTimeout authenticator specific timeout + * @param socketTimeout socket timeout (in ms) + * @param retryCount retry count for the request * @param ocspAndProxyKey OCSP mode and proxy settings for httpclient * @return response * @throws SnowflakeSQLException if Snowflake error occurs * @throws IOException raises if a general IO error occurs */ public static String executeGeneralRequest( - HttpRequestBase httpRequest, int retryTimeout, HttpClientSettingsKey ocspAndProxyKey) + HttpRequestBase httpRequest, + int retryTimeout, + int authTimeout, + int socketTimeout, + int retryCount, + HttpClientSettingsKey ocspAndProxyKey) throws SnowflakeSQLException, IOException { return executeRequest( httpRequest, retryTimeout, + authTimeout, + socketTimeout, + retryCount, 0, // no inject socket timeout null, // no canceling false, // no retry parameter @@ -571,17 +609,28 @@ public static String executeGeneralRequest( * * @param httpRequest HttpRequestBase * @param retryTimeout retry timeout + * @param authTimeout authenticator specific timeout + * @param socketTimeout socket timeout (in ms) + * @param retryCount retry count for the request * @param httpClient client object used to communicate with other machine * @return response * @throws SnowflakeSQLException if Snowflake error occurs * @throws IOException raises if a general IO error occurs */ public static String executeGeneralRequest( - HttpRequestBase httpRequest, int retryTimeout, CloseableHttpClient httpClient) + HttpRequestBase httpRequest, + int retryTimeout, + int authTimeout, + int socketTimeout, + int retryCount, + CloseableHttpClient httpClient) throws SnowflakeSQLException, IOException { return executeRequestInternal( httpRequest, retryTimeout, + authTimeout, + socketTimeout, + retryCount, 0, // no inject socket timeout null, // no canceling false, // with cookie @@ -596,6 +645,9 @@ public static String executeGeneralRequest( * * @param httpRequest HttpRequestBase * @param retryTimeout retry timeout + * @param authTimeout authenticator timeout + * @param socketTimeout socket timeout (in ms) + * @param retryCount retry count for the request * @param injectSocketTimeout injecting socket timeout * @param canceling canceling? * @param includeRetryParameters whether to include retry parameters in retried requests @@ -608,6 +660,9 @@ public static String executeGeneralRequest( public static String executeRequest( HttpRequestBase httpRequest, int retryTimeout, + int authTimeout, + int socketTimeout, + int retryCount, int injectSocketTimeout, AtomicBoolean canceling, boolean includeRetryParameters, @@ -617,6 +672,9 @@ public static String executeRequest( return executeRequestInternal( httpRequest, retryTimeout, + authTimeout, + socketTimeout, + retryCount, injectSocketTimeout, canceling, false, // with cookie (do we need cookie?) @@ -635,6 +693,9 @@ public static String executeRequest( * * @param httpRequest request object contains all the information * @param retryTimeout retry timeout (in seconds) + * @param authTimeout authenticator specific timeout (in seconds) + * @param socketTimeout socket timeout (in ms) + * @param retryCount retry count for the request * @param injectSocketTimeout simulate socket timeout * @param canceling canceling flag * @param withoutCookies whether this request should ignore cookies @@ -649,6 +710,9 @@ public static String executeRequest( private static String executeRequestInternal( HttpRequestBase httpRequest, int retryTimeout, + int authTimeout, + int socketTimeout, + int retryCount, int injectSocketTimeout, AtomicBoolean canceling, boolean withoutCookies, @@ -667,12 +731,16 @@ private static String executeRequestInternal( String theString; StringWriter writer = null; CloseableHttpResponse response = null; + try { response = RestRequest.execute( httpClient, httpRequest, retryTimeout, + authTimeout, + socketTimeout, + retryCount, injectSocketTimeout, canceling, withoutCookies, diff --git a/src/main/java/net/snowflake/client/core/Incident.java b/src/main/java/net/snowflake/client/core/Incident.java index 748ea7696..622965bcb 100644 --- a/src/main/java/net/snowflake/client/core/Incident.java +++ b/src/main/java/net/snowflake/client/core/Incident.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2019-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; @@ -232,6 +232,9 @@ public void flush() { HttpUtil.executeGeneralRequest( postRequest, 1000, + 0, + 0, + 0, ocspAndProxyKey != null ? ocspAndProxyKey : new HttpClientSettingsKey(OCSPMode.FAIL_OPEN)); diff --git a/src/main/java/net/snowflake/client/core/SFBaseSession.java b/src/main/java/net/snowflake/client/core/SFBaseSession.java index 951328734..b4cd097fe 100644 --- a/src/main/java/net/snowflake/client/core/SFBaseSession.java +++ b/src/main/java/net/snowflake/client/core/SFBaseSession.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; @@ -765,5 +765,11 @@ public SFConnectionHandler getSfConnectionHandler() { public abstract int getNetworkTimeoutInMilli(); + public abstract int getAuthTimeout(); + public abstract SnowflakeConnectString getSnowflakeConnectionString(); + + public abstract int getHttpClientConnectionTimeout(); + + public abstract int getHttpClientSocketTimeout(); } diff --git a/src/main/java/net/snowflake/client/core/SFLoginInput.java b/src/main/java/net/snowflake/client/core/SFLoginInput.java index ea550c7a7..f448d086a 100644 --- a/src/main/java/net/snowflake/client/core/SFLoginInput.java +++ b/src/main/java/net/snowflake/client/core/SFLoginInput.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; @@ -25,6 +25,7 @@ public class SFLoginInput { private String oktaUserName; private String accountName; private int loginTimeout = -1; // default is invalid + private int authTimeout = 0; private String userName; private String password; private boolean passcodeInPassword; @@ -139,6 +140,15 @@ SFLoginInput setLoginTimeout(int loginTimeout) { return this; } + int getAuthTimeout() { + return authTimeout; + } + + SFLoginInput setAuthTimeout(int authTimeout) { + this.authTimeout = authTimeout; + return this; + } + public String getUserName() { return userName; } diff --git a/src/main/java/net/snowflake/client/core/SFSession.java b/src/main/java/net/snowflake/client/core/SFSession.java index cb0d8ca65..c831abebf 100644 --- a/src/main/java/net/snowflake/client/core/SFSession.java +++ b/src/main/java/net/snowflake/client/core/SFSession.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; @@ -80,6 +80,7 @@ public class SFSession extends SFBaseSession { */ private int networkTimeoutInMilli = 0; // in milliseconds + private int authTimeout = 0; private boolean enableCombineDescribe = false; private int httpClientConnectionTimeout = 60000; // milliseconds private int httpClientSocketTimeout = DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT; // milliseconds @@ -166,7 +167,9 @@ public QueryStatus getQueryStatus(String queryID) throws SQLException { try { get.setHeader("Content-type", "application/json"); get.setHeader("Authorization", "Snowflake Token=\"" + this.sessionToken + "\""); - response = HttpUtil.executeGeneralRequest(get, loginTimeout, getHttpClientKey()); + response = + HttpUtil.executeGeneralRequest( + get, loginTimeout, authTimeout, httpClientSocketTimeout, 0, getHttpClientKey()); jsonNode = OBJECT_MAPPER.readTree(response); } catch (Exception e) { throw new SnowflakeSQLLoggedException( @@ -432,6 +435,7 @@ public synchronized void open() throws SFException, SnowflakeSQLException { .setOKTAUserName((String) connectionPropertiesMap.get(SFSessionProperty.OKTA_USERNAME)) .setAccountName((String) connectionPropertiesMap.get(SFSessionProperty.ACCOUNT)) .setLoginTimeout(loginTimeout) + .setAuthTimeout(authTimeout) .setUserName((String) connectionPropertiesMap.get(SFSessionProperty.USER)) .setPassword((String) connectionPropertiesMap.get(SFSessionProperty.PASSWORD)) .setToken((String) connectionPropertiesMap.get(SFSessionProperty.TOKEN)) @@ -457,6 +461,7 @@ public synchronized void open() throws SFException, SnowflakeSQLException { SessionUtil.openSession(loginInput, connectionPropertiesMap, tracingLevel.toString()); isClosed = false; + authTimeout = loginInput.getAuthTimeout(); sessionToken = loginOutput.getSessionToken(); masterToken = loginOutput.getMasterToken(); idToken = loginOutput.getIdToken(); @@ -716,7 +721,13 @@ protected void heartbeat() throws SFException, SQLException { // per https://support-snowflake.zendesk.com/agent/tickets/6629 int SF_HEARTBEAT_TIMEOUT = 300; String theResponse = - HttpUtil.executeGeneralRequest(postRequest, SF_HEARTBEAT_TIMEOUT, getHttpClientKey()); + HttpUtil.executeGeneralRequest( + postRequest, + SF_HEARTBEAT_TIMEOUT, + authTimeout, + httpClientSocketTimeout, + 0, + getHttpClientKey()); JsonNode rootNode; @@ -789,6 +800,18 @@ public int getNetworkTimeoutInMilli() { return networkTimeoutInMilli; } + public int getAuthTimeout() { + return authTimeout; + } + + public int getHttpClientSocketTimeout() { + return httpClientSocketTimeout; + } + + public int getHttpClientConnectionTimeout() { + return httpClientConnectionTimeout; + } + public boolean isClosed() { return isClosed; } diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index 5435acae7..94dd19fc3 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2020 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; @@ -346,8 +346,8 @@ private static SFLoginOutput newSession( Map commonParams; try { - uriBuilder = new URIBuilder(loginInput.getServerUrl()); + uriBuilder = new URIBuilder(loginInput.getServerUrl()); // add database name and schema name as query parameters if (loginInput.getDatabaseName() != null) { uriBuilder.addParameter(SF_QUERY_DATABASE, loginInput.getDatabaseName()); @@ -388,6 +388,7 @@ private static SFLoginOutput newSession( loginInput.getUserName()); loginInput.setToken(s.issueJwtToken()); + loginInput.setAuthTimeout(SessionUtilKeyPair.getTimeout()); } uriBuilder.addParameter(SFSession.SF_QUERY_REQUEST_ID, UUID.randomUUID().toString()); @@ -584,9 +585,73 @@ private static SFLoginOutput newSession( setServiceNameHeader(loginInput, postRequest); - String theString = - HttpUtil.executeGeneralRequest( - postRequest, loginInput.getLoginTimeout(), loginInput.getHttpClientSettingsKey()); + String theString = null; + int leftRetryTimeout = loginInput.getLoginTimeout(); + int leftsocketTimeout = loginInput.getSocketTimeout(); + int retryCount = 0; + + while (true) { + try { + theString = + HttpUtil.executeGeneralRequest( + postRequest, + leftRetryTimeout, + loginInput.getAuthTimeout(), + leftsocketTimeout, + retryCount, + loginInput.getHttpClientSettingsKey()); + } catch (SnowflakeSQLException ex) { + if (ex.getErrorCode() == ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT.getMessageCode()) { + if (authenticatorType == ClientAuthnDTO.AuthenticatorType.SNOWFLAKE_JWT) { + SessionUtilKeyPair s = + new SessionUtilKeyPair( + loginInput.getPrivateKey(), + loginInput.getPrivateKeyFile(), + loginInput.getPrivateKeyFilePwd(), + loginInput.getAccountName(), + loginInput.getUserName()); + + data.put(ClientAuthnParameter.TOKEN.name(), s.issueJwtToken()); + + long elapsedSeconds = ex.getElapsedSeconds(); + + if (loginInput.getLoginTimeout() > 0) { + if (leftRetryTimeout > elapsedSeconds) { + leftRetryTimeout -= elapsedSeconds; + } else { + leftRetryTimeout = 1; + } + } + + // In RestRequest.execute(), socket timeout is replaced with auth timeout + // so we can renew the request within auth timeout. + // auth timeout within socket timeout is thrown without backoff, + // and we need to update time remained in socket timeout here to control the + // the actual socket timeout from customer setting. + if (loginInput.getSocketTimeout() > 0) { + if (ex.issocketTimeoutNoBackoff()) { + if (leftsocketTimeout > elapsedSeconds) { + leftsocketTimeout -= elapsedSeconds; + } else { + leftsocketTimeout = 1; + } + } else { + // reset curl timeout for retry with backoff. + leftsocketTimeout = loginInput.getSocketTimeout(); + } + } + + // JWT renew should not count as a retry, so we pass back the current retry count. + retryCount = ex.getRetryCount(); + + continue; + } + } else { + throw ex; + } + } + break; + } // general method, same as with data binding JsonNode jsonNode = mapper.readTree(theString); @@ -849,7 +914,12 @@ private static SFLoginOutput tokenRequest(SFLoginInput loginInput, TokenRequestT String theString = HttpUtil.executeGeneralRequest( - postRequest, loginInput.getLoginTimeout(), loginInput.getHttpClientSettingsKey()); + postRequest, + loginInput.getLoginTimeout(), + loginInput.getAuthTimeout(), + loginInput.getSocketTimeout(), + 0, + loginInput.getHttpClientSettingsKey()); // general method, same as with data binding JsonNode jsonNode = mapper.readTree(theString); @@ -933,7 +1003,12 @@ static void closeSession(SFLoginInput loginInput) throws SFException, SnowflakeS String theString = HttpUtil.executeGeneralRequest( - postRequest, loginInput.getLoginTimeout(), loginInput.getHttpClientSettingsKey()); + postRequest, + loginInput.getLoginTimeout(), + loginInput.getAuthTimeout(), + loginInput.getSocketTimeout(), + 0, + loginInput.getHttpClientSettingsKey()); JsonNode rootNode; @@ -994,7 +1069,12 @@ private static String federatedFlowStep4( responseHtml = HttpUtil.executeGeneralRequest( - httpGet, loginInput.getLoginTimeout(), loginInput.getHttpClientSettingsKey()); + httpGet, + loginInput.getLoginTimeout(), + loginInput.getAuthTimeout(), + loginInput.getSocketTimeout(), + 0, + loginInput.getHttpClientSettingsKey()); // step 5 String postBackUrl = getPostBackUrlFromHTML(responseHtml); @@ -1059,6 +1139,9 @@ private static String federatedFlowStep3(SFLoginInput loginInput, String tokenUr HttpUtil.executeRequestWithoutCookies( postRequest, loginInput.getLoginTimeout(), + loginInput.getAuthTimeout(), + loginInput.getSocketTimeout(), + 0, 0, null, loginInput.getHttpClientSettingsKey()); @@ -1138,7 +1221,12 @@ private static JsonNode federatedFlowStep1(SFLoginInput loginInput) throws Snowf final String gsResponse = HttpUtil.executeGeneralRequest( - postRequest, loginInput.getLoginTimeout(), loginInput.getHttpClientSettingsKey()); + postRequest, + loginInput.getLoginTimeout(), + loginInput.getAuthTimeout(), + loginInput.getSocketTimeout(), + 0, + loginInput.getHttpClientSettingsKey()); logger.debug("authenticator-request response: {}", gsResponse); JsonNode jsonNode = mapper.readTree(gsResponse); diff --git a/src/main/java/net/snowflake/client/core/SessionUtilExternalBrowser.java b/src/main/java/net/snowflake/client/core/SessionUtilExternalBrowser.java index 7be97656c..46a27b8aa 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtilExternalBrowser.java +++ b/src/main/java/net/snowflake/client/core/SessionUtilExternalBrowser.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; @@ -179,7 +179,12 @@ private String getSSOUrl(int port) throws SFException, SnowflakeSQLException { String theString = HttpUtil.executeGeneralRequest( - postRequest, loginInput.getLoginTimeout(), loginInput.getHttpClientSettingsKey()); + postRequest, + loginInput.getLoginTimeout(), + loginInput.getAuthTimeout(), + loginInput.getSocketTimeout(), + 0, + loginInput.getHttpClientSettingsKey()); logger.debug("authenticator-request response: {}", theString); diff --git a/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java b/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java index 93a2999b4..8c27d2c44 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java +++ b/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java @@ -1,8 +1,10 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; +import static net.snowflake.client.jdbc.SnowflakeUtil.systemGetEnv; + import com.google.common.base.Strings; import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; @@ -55,6 +57,8 @@ class SessionUtilKeyPair { private static final String SUBJECT_FMT = "%s.%s"; + private static final int JWT_DEFAULT_AUTH_TIMEOUT = 10; + SessionUtilKeyPair( PrivateKey privateKey, String privateKeyFile, @@ -202,4 +206,13 @@ private String calculatePublicKeyFingerprint(PublicKey publicKey) throws SFExcep throw new SFException(e, ErrorCode.INTERNAL_ERROR, "Error when calculating fingerprint"); } } + + public static int getTimeout() { + String jwtAuthTimeoutStr = systemGetEnv("JWT_AUTH_TIMEOUT"); + int jwtAuthTimeout = JWT_DEFAULT_AUTH_TIMEOUT; + if (jwtAuthTimeoutStr != null) { + jwtAuthTimeout = Integer.parseInt(jwtAuthTimeoutStr); + } + return jwtAuthTimeout; + } } diff --git a/src/main/java/net/snowflake/client/core/StmtUtil.java b/src/main/java/net/snowflake/client/core/StmtUtil.java index 53bcd416b..3451f26f0 100644 --- a/src/main/java/net/snowflake/client/core/StmtUtil.java +++ b/src/main/java/net/snowflake/client/core/StmtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; @@ -81,6 +81,7 @@ static class StmtInput { Map parametersMap; String sessionToken; int networkTimeoutInMillis; + int socketTimeout; int injectSocketTimeout; // seconds int injectClientPause; // seconds @@ -162,6 +163,11 @@ public StmtInput setNetworkTimeoutInMillis(int networkTimeoutInMillis) { return this; } + public StmtInput setSocketTimeout(int socketTimeout) { + this.socketTimeout = socketTimeout; + return this; + } + public StmtInput setInjectSocketTimeout(int injectSocketTimeout) { this.injectSocketTimeout = injectSocketTimeout; return this; @@ -335,6 +341,9 @@ public static StmtOutput execute(StmtInput stmtInput) throws SFException, Snowfl HttpUtil.executeRequest( httpRequest, stmtInput.networkTimeoutInMillis / 1000, + stmtInput.socketTimeout, + 0, + 0, stmtInput.injectSocketTimeout, stmtInput.canceling, true, // include retry parameters @@ -573,6 +582,9 @@ protected static String getQueryResult(String getResultPath, StmtInput stmtInput return HttpUtil.executeRequest( httpRequest, stmtInput.networkTimeoutInMillis / 1000, + stmtInput.socketTimeout, + 0, + 0, 0, stmtInput.canceling, false, // no retry parameter @@ -605,6 +617,7 @@ protected static JsonNode getQueryResultJSON(String queryId, SFSession session) .setServerUrl(session.getServerUrl()) .setSessionToken(session.getSessionToken()) .setNetworkTimeoutInMillis(session.getNetworkTimeoutInMilli()) + .setSocketTimeout(session.getHttpClientSocketTimeout()) .setMediaType(SF_MEDIA_TYPE) .setServiceName(session.getServiceName()) .setOCSPMode(session.getOCSPMode()) @@ -686,6 +699,9 @@ public static void cancel(StmtInput stmtInput) throws SFException, SnowflakeSQLE httpRequest, SF_CANCELING_RETRY_TIMEOUT_IN_MILLIS, 0, + stmtInput.socketTimeout, + 0, + 0, null, false, // no retry parameter false, // no retry on HTTP 403 diff --git a/src/main/java/net/snowflake/client/jdbc/ChunkDownloadContext.java b/src/main/java/net/snowflake/client/jdbc/ChunkDownloadContext.java index ad97d2acf..804f516b5 100644 --- a/src/main/java/net/snowflake/client/jdbc/ChunkDownloadContext.java +++ b/src/main/java/net/snowflake/client/jdbc/ChunkDownloadContext.java @@ -35,6 +35,14 @@ public int getNetworkTimeoutInMilli() { return networkTimeoutInMilli; } + public int getAuthTimeout() { + return authTimeout; + } + + public int getSocketTimeout() { + return socketTimeout; + } + public SFBaseSession getSession() { return session; } @@ -44,6 +52,8 @@ public SFBaseSession getSession() { private final int chunkIndex; private final Map chunkHeadersMap; private final int networkTimeoutInMilli; + private final int authTimeout; + private final int socketTimeout; private final SFBaseSession session; public ChunkDownloadContext( @@ -53,6 +63,8 @@ public ChunkDownloadContext( int chunkIndex, Map chunkHeadersMap, int networkTimeoutInMilli, + int authTimeout, + int socketTimeout, SFBaseSession session) { this.chunkDownloader = chunkDownloader; this.resultChunk = resultChunk; @@ -60,6 +72,8 @@ public ChunkDownloadContext( this.chunkIndex = chunkIndex; this.chunkHeadersMap = chunkHeadersMap; this.networkTimeoutInMilli = networkTimeoutInMilli; + this.authTimeout = authTimeout; + this.socketTimeout = socketTimeout; this.session = session; } } diff --git a/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java b/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java index 12159f8dc..8f4af77d3 100644 --- a/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/DefaultResultStreamProvider.java @@ -122,6 +122,9 @@ else if (context.getQrmk() != null) { httpClient, httpRequest, context.getNetworkTimeoutInMilli() / 1000, // retry timeout + context.getAuthTimeout(), + context.getSocketTimeout(), + 0, 0, // no socketime injection null, // no canceling false, // no cookie diff --git a/src/main/java/net/snowflake/client/jdbc/ErrorCode.java b/src/main/java/net/snowflake/client/jdbc/ErrorCode.java index 7d60c3244..e0cb6d414 100644 --- a/src/main/java/net/snowflake/client/jdbc/ErrorCode.java +++ b/src/main/java/net/snowflake/client/jdbc/ErrorCode.java @@ -81,7 +81,8 @@ public enum ErrorCode { EXECUTE_BATCH_INTEGER_OVERFLOW(200058, SqlState.NUMERIC_VALUE_OUT_OF_RANGE), INVALID_CONNECT_STRING(200059, SqlState.CONNECTION_EXCEPTION), INVALID_OKTA_USERNAME(200060, SqlState.CONNECTION_EXCEPTION), - GCP_SERVICE_ERROR(200061, SqlState.SYSTEM_ERROR); + GCP_SERVICE_ERROR(200061, SqlState.SYSTEM_ERROR), + AUTHENTICATOR_REQUEST_TIMEOUT(200062, SqlState.CONNECTION_EXCEPTION); public static final String errorMessageResource = "net.snowflake.client.jdbc.jdbc_error_messages"; diff --git a/src/main/java/net/snowflake/client/jdbc/RestRequest.java b/src/main/java/net/snowflake/client/jdbc/RestRequest.java index ec7185035..015cb180c 100644 --- a/src/main/java/net/snowflake/client/jdbc/RestRequest.java +++ b/src/main/java/net/snowflake/client/jdbc/RestRequest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.jdbc; @@ -54,6 +54,9 @@ public class RestRequest { * @param httpClient client object used to communicate with other machine * @param httpRequest request object contains all the request information * @param retryTimeout : retry timeout (in seconds) + * @param authTimeout : authenticator specific timeout (in seconds) + * @param socketTimeout : curl timeout (in ms) + * @param retryCount : retry count for the request * @param injectSocketTimeout : simulate socket timeout * @param canceling canceling flag * @param withoutCookies whether the cookie spec should be set to IGNORE or not @@ -68,6 +71,9 @@ public static CloseableHttpResponse execute( CloseableHttpClient httpClient, HttpRequestBase httpRequest, long retryTimeout, + long authTimeout, + int socketTimeout, + int retryCount, int injectSocketTimeout, AtomicBoolean canceling, boolean withoutCookies, @@ -98,11 +104,12 @@ public static CloseableHttpResponse execute( // amount of time to wait for backing off before retry long backoffInMilli = minBackoffInMilli; + // auth timeout (ms) + long authTimeoutInMilli = authTimeout * 1000; + DecorrelatedJitterBackoff backoff = new DecorrelatedJitterBackoff(backoffInMilli, maxBackoffInMilli); - int retryCount = 0; - int origSocketTimeout = 0; Exception savedEx = null; @@ -150,6 +157,15 @@ public static CloseableHttpResponse execute( } } + // When the auth timeout is set, set the socket timeout as the authTimeout + // so that it can be renewed in time and pass it to the http request configuration. + if (authTimeout > 0) { + int requestSocketAndConnectTimeout = (int) authTimeout * 1000; + httpRequest.setConfig( + HttpUtil.getDefaultRequestConfigWithSocketAndConnectTimeout( + requestSocketAndConnectTimeout, withoutCookies)); + } + if (includeRequestGuid) { // Add request_guid for better tracing builder.setParameter(SF_REQUEST_GUID, UUID.randomUUID().toString()); @@ -215,6 +231,8 @@ public static CloseableHttpResponse execute( Event.EventType.NETWORK_ERROR, msg + ", Request: " + httpRequest.toString(), false); } breakRetryReason = "status code does not need retry"; + // reset retryCount + retryCount = 0; break; } else { if (response != null) { @@ -259,6 +277,7 @@ public static CloseableHttpResponse execute( + "Elapsed: {}(ms), timeout: {}(ms)", elapsedMilliForTransientIssues, retryTimeoutInMilliseconds); + breakRetryReason = "retry timeout"; TelemetryService.getInstance() .logHttpRequestTelemetryEvent( @@ -284,10 +303,33 @@ public static CloseableHttpResponse execute( "Exception encountered for HTTP request: " + savedEx.getMessage()); } // no more retry + // reset state + retryCount = 0; break; } } + // Make sure that any authenticator specific info that needs to be + // updated get's updated before the next retry. Ex - JWT token + // Check to see if customer set socket/connect timeout has been reached, + // if not we don't increase the retry count since JWT renew doesn't count as a retry + // attempt. + if (authTimeout > 0 + && elapsedMilliForTransientIssues > authTimeoutInMilli + && (socketTimeout == 0 + || elapsedMilliForTransientIssues + < socketTimeout)) /* socket timeout not reached */ { + /* connect timeout not reached */ + // check if this is a login-request + if (String.valueOf(httpRequest.getURI()).contains("login-request")) { + throw new SnowflakeSQLException( + ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT, + retryCount, + true, + elapsedMilliForTransientIssues / 1000); + } + } + logger.debug("Retrying request: {}", requestInfoScrubbed); // sleep for backoff - elapsed amount of time @@ -303,6 +345,19 @@ public static CloseableHttpResponse execute( } retryCount++; + + // If the request failed with any other retry-able error and auth timeout is reached + // increase the retry count and throw special exception to renew the token before retrying. + if (authTimeout > 0) { + if (elapsedMilliForTransientIssues >= authTimeoutInMilli) { + throw new SnowflakeSQLException( + ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT, + retryCount, + false, + elapsedMilliForTransientIssues / 1000); + } + } + int numOfRetryToTriggerTelemetry = TelemetryService.getInstance().getNumOfRetryToTriggerTelemetry(); if (retryCount == numOfRetryToTriggerTelemetry) { diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java index e65d1ee92..1e37b57d4 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.jdbc; @@ -84,6 +84,10 @@ public class SnowflakeChunkDownloader implements ChunkDownloader { private final int networkTimeoutInMilli; + private final int authTimeout; + + private final int socketTimeout; + private long memoryLimit; // the current memory usage across JVM @@ -190,6 +194,8 @@ public SnowflakeChunkDownloader(SnowflakeResultSetSerializableV1 resultSetSerial this.ocspModeAndProxyKey = resultSetSerializable.getHttpClientKey(); this.qrmk = resultSetSerializable.getQrmk(); this.networkTimeoutInMilli = resultSetSerializable.getNetworkTimeoutInMilli(); + this.authTimeout = resultSetSerializable.getAuthTimeout(); + this.socketTimeout = resultSetSerializable.getSocketTimeout(); this.prefetchSlots = resultSetSerializable.getResultPrefetchThreads() * 2; this.queryResultFormat = resultSetSerializable.getQueryResultFormat(); logger.debug("qrmk = {}", this.qrmk); @@ -380,6 +386,8 @@ private void startNextDownloaders() throws SnowflakeSQLException { nextChunkToDownload, chunkHeadersMap, networkTimeoutInMilli, + authTimeout, + socketTimeout, this.session)); downloaderFutures.put(nextChunkToDownload, downloaderFuture); // increment next chunk to download @@ -674,6 +682,8 @@ private void waitForChunkReady(SnowflakeResultChunk currentChunk) throws Interru nextChunkToConsume, chunkHeadersMap, networkTimeoutInMilli, + authTimeout, + socketTimeout, session)); downloaderFutures.put(nextChunkToDownload, downloaderFuture); // Only when prefetch fails due to internal memory limitation, nextChunkToDownload @@ -827,6 +837,8 @@ private static Callable getDownloadChunkCallable( final int chunkIndex, final Map chunkHeadersMap, final int networkTimeoutInMilli, + final int authTimeout, + final int socketTimeout, final SFBaseSession session) { ChunkDownloadContext downloadContext = new ChunkDownloadContext( @@ -836,6 +848,8 @@ private static Callable getDownloadChunkCallable( chunkIndex, chunkHeadersMap, networkTimeoutInMilli, + authTimeout, + socketTimeout, session); return new Callable() { diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java index d31c9e018..2089ee579 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.jdbc; @@ -115,6 +115,8 @@ public String toString() { OCSPMode ocspMode; HttpClientSettingsKey httpClientKey; int networkTimeoutInMilli; + int authTimeout; + int socketTimeout; boolean isResultColumnCaseInsensitive; int resultSetType; int resultSetConcurrency; @@ -187,6 +189,8 @@ private SnowflakeResultSetSerializableV1(SnowflakeResultSetSerializableV1 toCopy this.ocspMode = toCopy.ocspMode; this.httpClientKey = toCopy.httpClientKey; this.networkTimeoutInMilli = toCopy.networkTimeoutInMilli; + this.authTimeout = toCopy.authTimeout; + this.socketTimeout = toCopy.socketTimeout; this.isResultColumnCaseInsensitive = toCopy.isResultColumnCaseInsensitive; this.resultSetType = toCopy.resultSetType; this.resultSetConcurrency = toCopy.resultSetConcurrency; @@ -297,6 +301,14 @@ public int getNetworkTimeoutInMilli() { return networkTimeoutInMilli; } + public int getAuthTimeout() { + return authTimeout; + } + + public int getSocketTimeout() { + return socketTimeout; + } + public int getResultPrefetchThreads() { return resultPrefetchThreads; } @@ -616,6 +628,7 @@ public static SnowflakeResultSetSerializableV1 create( resultSetSerializable.httpClientKey = sfSession.getHttpClientKey(); resultSetSerializable.snowflakeConnectionString = sfSession.getSnowflakeConnectionString(); resultSetSerializable.networkTimeoutInMilli = sfSession.getNetworkTimeoutInMilli(); + resultSetSerializable.authTimeout = sfSession.getAuthTimeout(); resultSetSerializable.isResultColumnCaseInsensitive = sfSession.isResultColumnCaseInsensitive(); resultSetSerializable.treatNTZAsUTC = sfSession.getTreatNTZAsUTC(); resultSetSerializable.formatDateWithTimezone = sfSession.getFormatDateWithTimezone(); diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeSQLException.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeSQLException.java index defad54dc..65144a446 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeSQLException.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeSQLException.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.jdbc; @@ -20,6 +20,10 @@ public class SnowflakeSQLException extends SQLException { ResourceBundleManager.getSingleton(ErrorCode.errorMessageResource); private String queryId = "unknown"; + private int retryCount = 0; + + boolean issocketTimeoutNoBackoff; + long elapsedSeconds; /** * This constructor should only be used for error from Global service. Since Global service has @@ -115,6 +119,17 @@ public SnowflakeSQLException(ErrorCode errorCode, Object... params) { errorCode.getMessageCode()); } + public SnowflakeSQLException( + ErrorCode errorCode, int retryCount, boolean issocketTimeoutNoBackoff, long elapsedSeconds) { + super( + errorResourceBundleManager.getLocalizedMessage(String.valueOf(errorCode.getMessageCode())), + errorCode.getSqlState(), + errorCode.getMessageCode()); + this.retryCount = retryCount; + this.issocketTimeoutNoBackoff = issocketTimeoutNoBackoff; + this.elapsedSeconds = elapsedSeconds; + } + public SnowflakeSQLException(SFException e) { this(e.getQueryId(), e.getMessage(), e.getSqlState(), e.getVendorCode()); } @@ -126,4 +141,16 @@ public SnowflakeSQLException(String reason) { public String getQueryId() { return queryId; } + + public int getRetryCount() { + return retryCount; + } + + public boolean issocketTimeoutNoBackoff() { + return issocketTimeoutNoBackoff; + } + + public long getElapsedSeconds() { + return elapsedSeconds; + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/SnowflakeGCSClient.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/SnowflakeGCSClient.java index a7cbfcaab..dc9d6e2e7 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/SnowflakeGCSClient.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/SnowflakeGCSClient.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.jdbc.cloud.storage; @@ -233,6 +233,9 @@ public void download( httpClient, httpRequest, session.getNetworkTimeoutInMilli() / 1000, // retry timeout + session.getAuthTimeout(), + session.getHttpClientSocketTimeout(), + 0, 0, // no socketime injection null, // no canceling false, // no cookie @@ -395,6 +398,9 @@ public InputStream downloadToStream( httpClient, httpRequest, session.getNetworkTimeoutInMilli() / 1000, // retry timeout + session.getAuthTimeout(), + session.getHttpClientSocketTimeout(), + 0, 0, // no socketime injection null, // no canceling false, // no cookie @@ -737,6 +743,9 @@ private void uploadWithPresignedUrl( httpClient, httpRequest, networkTimeoutInMilli / 1000, // retry timeout + session.getAuthTimeout(), + session.getHttpClientSocketTimeout(), + 0, 0, // no socketime injection null, // no canceling false, // no cookie diff --git a/src/main/java/net/snowflake/client/jdbc/telemetry/TelemetryClient.java b/src/main/java/net/snowflake/client/jdbc/telemetry/TelemetryClient.java index efdc4fddb..d22b546b5 100644 --- a/src/main/java/net/snowflake/client/jdbc/telemetry/TelemetryClient.java +++ b/src/main/java/net/snowflake/client/jdbc/telemetry/TelemetryClient.java @@ -327,9 +327,19 @@ private boolean sendBatch() throws IOException { response = this.session == null ? HttpUtil.executeGeneralRequest( - post, TELEMETRY_HTTP_RETRY_TIMEOUT_IN_SEC, this.httpClient) + post, + TELEMETRY_HTTP_RETRY_TIMEOUT_IN_SEC, + this.session.getAuthTimeout(), + this.session.getHttpClientSocketTimeout(), + 0, + this.httpClient) : HttpUtil.executeGeneralRequest( - post, TELEMETRY_HTTP_RETRY_TIMEOUT_IN_SEC, this.session.getHttpClientKey()); + post, + TELEMETRY_HTTP_RETRY_TIMEOUT_IN_SEC, + this.session.getAuthTimeout(), + this.session.getHttpClientSocketTimeout(), + 0, + this.session.getHttpClientKey()); } catch (SnowflakeSQLException e) { disableTelemetry(); // when got error like 404 or bad request, disable telemetry in this // telemetry instance diff --git a/src/test/java/net/snowflake/client/core/SessionUtilExternalBrowserTest.java b/src/test/java/net/snowflake/client/core/SessionUtilExternalBrowserTest.java index 54a86d145..764719275 100644 --- a/src/test/java/net/snowflake/client/core/SessionUtilExternalBrowserTest.java +++ b/src/test/java/net/snowflake/client/core/SessionUtilExternalBrowserTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2020 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; @@ -137,6 +137,9 @@ public void testSessionUtilExternalBrowser() throws Throwable { HttpUtil.executeGeneralRequest( Mockito.any(HttpRequestBase.class), Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), Mockito.nullable(HttpClientSettingsKey.class))) .thenReturn( "{\"success\":\"true\",\"data\":{\"proofKey\":\"" @@ -173,6 +176,9 @@ public void testSessionUtilExternalBrowserFail() throws Throwable { HttpUtil.executeGeneralRequest( Mockito.any(HttpRequestBase.class), Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), Mockito.nullable(HttpClientSettingsKey.class))) .thenReturn("{\"success\":\"false\",\"code\":\"123456\",\"message\":\"errormes\"}"); diff --git a/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java b/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java new file mode 100644 index 000000000..1c5eb794c --- /dev/null +++ b/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.client.core; + +import static net.snowflake.client.TestUtil.systemGetEnv; +import static org.mockito.Mockito.*; + +import java.util.HashMap; +import java.util.Map; +import net.snowflake.client.jdbc.ErrorCode; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.common.core.ClientAuthnDTO; +import org.apache.http.client.methods.HttpRequestBase; +import org.junit.Ignore; +import org.junit.Test; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +public class SessionUtilLatestIT { + + /** + * Tests the JWT renew functionality when retrying login requests. To run, update environment + * variables to use connect with JWT authentication. + * + * @throws SFException + * @throws SnowflakeSQLException + */ + @Ignore + @Test + public void testJwtAuthTimeoutRetry() throws SFException, SnowflakeSQLException { + final SFLoginInput loginInput = initMockLoginInput(); + Map connectionPropertiesMap = initConnectionPropertiesMap(); + MockedStatic mockedHttpUtil = mockStatic(HttpUtil.class); + SnowflakeSQLException ex = + new SnowflakeSQLException(ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT, 0, true, 0); + + mockedHttpUtil + .when( + () -> + HttpUtil.executeGeneralRequest( + Mockito.any(HttpRequestBase.class), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.nullable(HttpClientSettingsKey.class))) + .thenThrow(ex) // fail first + .thenReturn( + "{\"data\":null,\"code\":null,\"message\":null,\"success\":true}"); // succeed on retry + + SessionUtil.openSession(loginInput, connectionPropertiesMap, "ALL"); + } + + /** + * Mock SFLoginInput + * + * @return a mock object for SFLoginInput + */ + private SFLoginInput initMockLoginInput() { + // mock SFLoginInput + SFLoginInput loginInput = mock(SFLoginInput.class); + when(loginInput.getServerUrl()).thenReturn(systemGetEnv("SNOWFLAKE_TEST_HOST")); + when(loginInput.getAuthenticator()) + .thenReturn(ClientAuthnDTO.AuthenticatorType.SNOWFLAKE_JWT.name()); + when(loginInput.getPrivateKeyFile()) + .thenReturn(systemGetEnv("SNOWFLAKE_TEST_PRIVATE_KEY_FILE")); + when(loginInput.getPrivateKeyFilePwd()) + .thenReturn(systemGetEnv("SNOWFLAKE_TEST_PRIVATE_KEY_FILE_PWD")); + when(loginInput.getUserName()).thenReturn(systemGetEnv("SNOWFLAKE_TEST_USER")); + when(loginInput.getAccountName()).thenReturn("testaccount"); + when(loginInput.getAppId()).thenReturn("testid"); + when(loginInput.getOCSPMode()).thenReturn(OCSPMode.FAIL_OPEN); + when(loginInput.getHttpClientSettingsKey()) + .thenReturn(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN)); + return loginInput; + } + + /** + * Initialize the connection properties map. + * + * @return connectionPropertiesMap + */ + private Map initConnectionPropertiesMap() { + Map connectionPropertiesMap = new HashMap<>(); + connectionPropertiesMap.put(SFSessionProperty.TRACING, "ALL"); + return connectionPropertiesMap; + } +} diff --git a/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java b/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java index 8e1bfeff3..0dd691ed3 100644 --- a/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java +++ b/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.core; @@ -101,7 +101,12 @@ public void testMFAFunctionality() throws SQLException { .when( () -> HttpUtil.executeGeneralRequest( - any(HttpPost.class), anyInt(), any(HttpClientSettingsKey.class))) + any(HttpPost.class), + anyInt(), + anyInt(), + anyInt(), + anyInt(), + any(HttpClientSettingsKey.class))) .thenAnswer( new Answer() { int callCount = 0; @@ -245,7 +250,12 @@ private void unavailableLSSWindowsTestBody() throws SQLException { .when( () -> HttpUtil.executeGeneralRequest( - any(HttpPost.class), anyInt(), any(HttpClientSettingsKey.class))) + any(HttpPost.class), + anyInt(), + anyInt(), + anyInt(), + anyInt(), + any(HttpClientSettingsKey.class))) .thenAnswer( new Answer() { int callCount = 0; diff --git a/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java b/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java index 558fe42b5..2eef32f6c 100644 --- a/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2020 Snowflake Computing Inc. All right reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All right reserved. */ package net.snowflake.client.jdbc; @@ -933,7 +933,7 @@ public void testAuthenticatorEndpointWithDashInAccountName() throws Exception { postRequest.addHeader("accept", "application/json"); String theString = - HttpUtil.executeGeneralRequest(postRequest, 60, new HttpClientSettingsKey(null)); + HttpUtil.executeGeneralRequest(postRequest, 60, 0, 0, 0, new HttpClientSettingsKey(null)); JsonNode jsonNode = mapper.readTree(theString); assertEquals( diff --git a/src/test/java/net/snowflake/client/jdbc/MockConnectionTest.java b/src/test/java/net/snowflake/client/jdbc/MockConnectionTest.java index 6a652aeaa..666653254 100644 --- a/src/test/java/net/snowflake/client/jdbc/MockConnectionTest.java +++ b/src/test/java/net/snowflake/client/jdbc/MockConnectionTest.java @@ -701,9 +701,23 @@ public int getNetworkTimeoutInMilli() { return 0; } + public int getAuthTimeout() { + return 0; + } + public SnowflakeConnectString getSnowflakeConnectionString() { return null; } + + @Override + public int getHttpClientConnectionTimeout() { + return 0; + } + + @Override + public int getHttpClientSocketTimeout() { + return 0; + } } private static class MockSFFileTransferAgent extends SFBaseFileTransferAgent { diff --git a/src/test/java/net/snowflake/client/jdbc/RestRequestTest.java b/src/test/java/net/snowflake/client/jdbc/RestRequestTest.java index a1a7a86ac..2ef43be24 100644 --- a/src/test/java/net/snowflake/client/jdbc/RestRequestTest.java +++ b/src/test/java/net/snowflake/client/jdbc/RestRequestTest.java @@ -1,8 +1,10 @@ /* - * Copyright (c) 2012-2020 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.jdbc; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -10,7 +12,9 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; +import net.snowflake.client.core.HttpUtil; import org.apache.http.StatusLine; +import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpUriRequest; @@ -21,6 +25,10 @@ /** RestRequest unit tests. */ public class RestRequestTest { + + static final int DEFAULT_CONNECTION_TIMEOUT = 60000; + static final int DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT = 300000; // ms + private CloseableHttpResponse retryResponse() { StatusLine retryStatusLine = mock(StatusLine.class); when(retryStatusLine.getStatusCode()).thenReturn(503); @@ -41,12 +49,31 @@ private CloseableHttpResponse successResponse() { return successResponse; } - private void execute(CloseableHttpClient client, String uri, boolean includeRetryParameters) + private void execute( + CloseableHttpClient client, + String uri, + int retryTimeout, + int authTimeout, + int socketTimeout, + boolean includeRetryParameters) throws IOException, SnowflakeSQLException { + + RequestConfig.Builder builder = + RequestConfig.custom() + .setConnectTimeout(DEFAULT_CONNECTION_TIMEOUT) + .setConnectionRequestTimeout(DEFAULT_CONNECTION_TIMEOUT) + .setSocketTimeout(DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT); + RequestConfig defaultRequestConfig = builder.build(); + HttpUtil util = new HttpUtil(); + util.setRequestConfig(defaultRequestConfig); + RestRequest.execute( client, new HttpGet(uri), - 0, // retry timeout + retryTimeout, // retry timeout + authTimeout, + socketTimeout, + 0, 0, // inject socket timeout new AtomicBoolean(false), // canceling false, // without cookie @@ -87,7 +114,7 @@ public CloseableHttpResponse answer(InvocationOnMock invocation) throws Throwabl } }); - execute(client, "fakeurl.com/?requestId=abcd-1234", true); + execute(client, "fakeurl.com/?requestId=abcd-1234", 0, 0, 0, true); } @Test @@ -118,7 +145,7 @@ public CloseableHttpResponse answer(InvocationOnMock invocation) throws Throwabl } }); - execute(client, "fakeurl.com/?requestId=abcd-1234", false); + execute(client, "fakeurl.com/?requestId=abcd-1234", 0, 0, 0, false); } private CloseableHttpResponse anyStatusCodeResponse(int statusCode) { @@ -270,4 +297,18 @@ class TestCase { } } } + + @Test + public void testExceptionAuthBasedTimeout() throws IOException { + CloseableHttpClient client = mock(CloseableHttpClient.class); + when(client.execute(any(HttpUriRequest.class))) + .thenAnswer((Answer) invocation -> retryResponse()); + + try { + execute(client, "login-request.com/?requestId=abcd-1234", 2, 1, 30000, true); + } catch (SnowflakeSQLException ex) { + assertThat( + ex.getErrorCode(), equalTo(ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT.getMessageCode())); + } + } } diff --git a/src/test/java/net/snowflake/client/jdbc/SSOConnectionTest.java b/src/test/java/net/snowflake/client/jdbc/SSOConnectionTest.java index d79bf9d53..22e846e5b 100644 --- a/src/test/java/net/snowflake/client/jdbc/SSOConnectionTest.java +++ b/src/test/java/net/snowflake/client/jdbc/SSOConnectionTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2020 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.jdbc; @@ -217,7 +217,12 @@ private void initMockHttpUtil(MockedStatic mockedHttpUtil) throws IOEx .when( () -> HttpUtil.executeGeneralRequest( - any(HttpPost.class), anyInt(), nullable(HttpClientSettingsKey.class))) + any(HttpPost.class), + anyInt(), + anyInt(), + anyInt(), + anyInt(), + nullable(HttpClientSettingsKey.class))) .thenAnswer( new Answer() { int callCount = 0; diff --git a/src/test/java/net/snowflake/client/jdbc/ServiceNameTest.java b/src/test/java/net/snowflake/client/jdbc/ServiceNameTest.java index 6a7d281b2..28a756b2d 100644 --- a/src/test/java/net/snowflake/client/jdbc/ServiceNameTest.java +++ b/src/test/java/net/snowflake/client/jdbc/ServiceNameTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2020 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.client.jdbc; @@ -99,6 +99,9 @@ public void testAddServiceNameToRequestHeader() throws Throwable { HttpUtil.executeGeneralRequest( Mockito.any(HttpRequestBase.class), Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), Mockito.any(HttpClientSettingsKey.class))) .thenReturn(responseLogin()); mockedHttpUtil @@ -108,6 +111,9 @@ public void testAddServiceNameToRequestHeader() throws Throwable { Mockito.any(HttpRequestBase.class), Mockito.anyInt(), Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), + Mockito.anyInt(), Mockito.any(AtomicBoolean.class), Mockito.anyBoolean(), Mockito.anyBoolean(),