diff --git a/singer_sdk/authenticators.py b/singer_sdk/authenticators.py index ac9bb2807..61382daba 100644 --- a/singer_sdk/authenticators.py +++ b/singer_sdk/authenticators.py @@ -358,6 +358,7 @@ def __init__( auth_endpoint: str | None = None, oauth_scopes: str | None = None, default_expiration: int | None = None, + oauth_headers: dict | None = None, ) -> None: """Create a new authenticator. @@ -366,11 +367,13 @@ def __init__( auth_endpoint: The OAuth 2.0 authorization endpoint. oauth_scopes: A comma-separated list of OAuth scopes. default_expiration: Default token expiry in seconds. + oauth_headers: An optional dict of headers required to get a token. """ super().__init__(stream=stream) self._auth_endpoint = auth_endpoint self._default_expiration = default_expiration self._oauth_scopes = oauth_scopes + self._oauth_headers = oauth_headers or {} # Initialize internal tracking attributes self.access_token: str | None = None @@ -499,6 +502,7 @@ def update_access_token(self) -> None: auth_request_payload = self.oauth_request_payload token_response = requests.post( self.auth_endpoint, + headers=self._oauth_headers, data=auth_request_payload, timeout=60, ) @@ -512,7 +516,8 @@ def update_access_token(self) -> None: token_json = token_response.json() self.access_token = token_json["access_token"] - self.expires_in = token_json.get("expires_in", self._default_expiration) + expiration = token_json.get("expires_in", self._default_expiration) + self.expires_in = int(expiration) if expiration else None if self.expires_in is None: self.logger.debug( "No expires_in receied in OAuth response and no "