diff --git a/vertica_python/vertica/connection.py b/vertica_python/vertica/connection.py index 6cfe207f..acd6f3b8 100644 --- a/vertica_python/vertica/connection.py +++ b/vertica_python/vertica/connection.py @@ -474,6 +474,14 @@ def closed(self) -> bool: """Returns True if the connection is closed.""" return not self.opened() + def get_current_refresh_token(self) -> str: + """Returns the current refresh token. + + This may be different from the user supplied token if token refresh + was required and token rotation is in effect + """ + return self.oauth_refresh_token + def __str__(self) -> str: safe_options = {key: value for key, value in self.options.items() if key != 'password'} @@ -920,7 +928,7 @@ def startup_connection(self) -> bool: # If access token is not set, will attempt to set a new one by using token refresh if len(self.oauth_access_token) == 0 and self.oauth_manager and not self.oauth_manager.refresh_attempted: self._logger.info("Issuing an OAuth access token using a refresh token") - self.oauth_access_token = self.oauth_manager.do_token_refresh() + self.oauth_access_token, self.oauth_refresh_token = self.oauth_manager.do_token_refresh() self.write(messages.Password(self.oauth_access_token, message.code)) else: self.write(messages.Password(password, message.code, @@ -940,7 +948,7 @@ def startup_connection(self) -> bool: raise errors.ConnectionError("Did not receive proper OAuth Authentication response from server. Please upgrade to the latest Vertica server for OAuth Support.") self.close_socket() self._logger.info("Issuing a new OAuth access token using a refresh token") - self.oauth_access_token = self.oauth_manager.do_token_refresh() + self.oauth_access_token, self.oauth_refresh_token = self.oauth_manager.do_token_refresh() return True raise errors.ConnectionError(message.error_message()) else: diff --git a/vertica_python/vertica/oauth_manager.py b/vertica_python/vertica/oauth_manager.py index 87e39da9..1df1af13 100644 --- a/vertica_python/vertica/oauth_manager.py +++ b/vertica_python/vertica/oauth_manager.py @@ -80,7 +80,12 @@ def get_access_token_using_refresh_token(self) -> str: # TODO handle self.validate_cert_hostname response = requests.post(self.token_url, headers=headers, data=params, verify=False) response.raise_for_status() - return response.json()["access_token"] + json_response = response.json() + # If refresh token rotation is used, like in OTDS, we will get both our new valid access token as well as + # a new refresh token to use the next time we need to invoke token refresh. + if 'refresh_token' in json_response: + self.refresh_token = json_response["refresh_token"] + return response.json()["access_token"], self.refresh_token except requests.exceptions.HTTPError as err: msg = f'{err_msg}\n{err}\n{response.json()}' raise OAuthTokenRefreshError(msg)