From 3aa9cf627d5c48d8e317348ee11e65b2bf370b82 Mon Sep 17 00:00:00 2001 From: Steve Clarke <84364906+s7clarke10@users.noreply.github.com> Date: Fri, 7 Jul 2023 15:56:32 +1200 Subject: [PATCH 1/2] Supporting optional headers for oauth. --- singer_sdk/authenticators.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/singer_sdk/authenticators.py b/singer_sdk/authenticators.py index ac9bb2807..cf74cac80 100644 --- a/singer_sdk/authenticators.py +++ b/singer_sdk/authenticators.py @@ -358,6 +358,7 @@ def __init__( auth_endpoint: str | None = None, oauth_scopes: str | None = None, default_expiration: int | None = None, + oauth_headers: dict | None = None, ) -> None: """Create a new authenticator. @@ -377,6 +378,10 @@ def __init__( self.refresh_token: str | None = None self.last_refreshed: datetime | None = None self.expires_in: int | None = None + if self._auth_headers is None: + self._auth_headers = {} + if oauth_headers: + self._auth_headers.update(oauth_headers) @property def auth_headers(self) -> dict: @@ -499,6 +504,7 @@ def update_access_token(self) -> None: auth_request_payload = self.oauth_request_payload token_response = requests.post( self.auth_endpoint, + headers=self._auth_headers, data=auth_request_payload, timeout=60, ) @@ -512,7 +518,7 @@ def update_access_token(self) -> None: token_json = token_response.json() self.access_token = token_json["access_token"] - self.expires_in = token_json.get("expires_in", self._default_expiration) + self.expires_in = int(token_json.get("expires_in", self._default_expiration)) if self.expires_in is None: self.logger.debug( "No expires_in receied in OAuth response and no " From 430a2280a1922e3ae2adca5c8b44ee1900856a88 Mon Sep 17 00:00:00 2001 From: Steve Clarke <84364906+s7clarke10@users.noreply.github.com> Date: Mon, 10 Jul 2023 11:54:39 +1200 Subject: [PATCH 2/2] Optional OAuth only Headers with tests for None --- singer_sdk/authenticators.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/singer_sdk/authenticators.py b/singer_sdk/authenticators.py index cf74cac80..61382daba 100644 --- a/singer_sdk/authenticators.py +++ b/singer_sdk/authenticators.py @@ -367,21 +367,19 @@ def __init__( auth_endpoint: The OAuth 2.0 authorization endpoint. oauth_scopes: A comma-separated list of OAuth scopes. default_expiration: Default token expiry in seconds. + oauth_headers: An optional dict of headers required to get a token. """ super().__init__(stream=stream) self._auth_endpoint = auth_endpoint self._default_expiration = default_expiration self._oauth_scopes = oauth_scopes + self._oauth_headers = oauth_headers or {} # Initialize internal tracking attributes self.access_token: str | None = None self.refresh_token: str | None = None self.last_refreshed: datetime | None = None self.expires_in: int | None = None - if self._auth_headers is None: - self._auth_headers = {} - if oauth_headers: - self._auth_headers.update(oauth_headers) @property def auth_headers(self) -> dict: @@ -504,7 +502,7 @@ def update_access_token(self) -> None: auth_request_payload = self.oauth_request_payload token_response = requests.post( self.auth_endpoint, - headers=self._auth_headers, + headers=self._oauth_headers, data=auth_request_payload, timeout=60, ) @@ -518,7 +516,8 @@ def update_access_token(self) -> None: token_json = token_response.json() self.access_token = token_json["access_token"] - self.expires_in = int(token_json.get("expires_in", self._default_expiration)) + expiration = token_json.get("expires_in", self._default_expiration) + self.expires_in = int(expiration) if expiration else None if self.expires_in is None: self.logger.debug( "No expires_in receied in OAuth response and no "