Skip to content

Commit

Permalink
Make optional OAuth2 request parameters configurable (apache#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
himadripal authored Mar 6, 2024
1 parent 14c021b commit 29fd42c
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 3 deletions.
2 changes: 2 additions & 0 deletions mkdocs/docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ catalog:
| credential | t-1234:secret | Credential to use for OAuth2 credential flow when initializing the catalog |
| token | FEW23.DFSDF.FSDF | Bearer token value to use for `Authorization` header |
| scope | openid offline corpds:ds:profile | Desired scope of the requested security token (default : catalog) |
| resource | rest_catalog.iceberg.com | URI for the target resource or service |
| audience | rest_catalog | Logical name of target resource or service |
| rest.sigv4-enabled | true | Sign requests to the REST Server using AWS SigV4 protocol |
| rest.signing-region | us-east-1 | The region to use when SigV4 signing a request |
| rest.signing-name | execute-api | The service signing name to use when SigV4 signing a request |
Expand Down
18 changes: 15 additions & 3 deletions pyiceberg/catalog/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class Endpoints:
CREDENTIAL = "credential"
GRANT_TYPE = "grant_type"
SCOPE = "scope"
AUDIENCE = "audience"
RESOURCE = "resource"
TOKEN_EXCHANGE = "urn:ietf:params:oauth:grant-type:token-exchange"
SEMICOLON = ":"
KEY = "key"
Expand Down Expand Up @@ -289,16 +291,26 @@ def auth_url(self) -> str:
else:
return self.url(Endpoints.get_token, prefixed=False)

def _extract_optional_oauth_params(self) -> Dict[str, str]:
optional_oauth_param = {SCOPE: self.properties.get(SCOPE) or CATALOG_SCOPE}
set_of_optional_params = {AUDIENCE, RESOURCE}
for param in set_of_optional_params:
if param_value := self.properties.get(param):
optional_oauth_param[param] = param_value

return optional_oauth_param

def _fetch_access_token(self, session: Session, credential: str) -> str:
if SEMICOLON in credential:
client_id, client_secret = credential.split(SEMICOLON)
else:
client_id, client_secret = None, credential

# take scope from properties or use default CATALOG_SCOPE
scope = self.properties.get(SCOPE) or CATALOG_SCOPE
data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret}

optional_oauth_params = self._extract_optional_oauth_params()
data.update(optional_oauth_params)

data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret, SCOPE: scope}
response = session.post(
url=self.auth_url, data=data, headers={**session.headers, "Content-type": "application/x-www-form-urlencoded"}
)
Expand Down
45 changes: 45 additions & 0 deletions tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
TEST_AUTH_URL = "https://auth-endpoint/"
TEST_TOKEN = "some_jwt_token"
TEST_SCOPE = "openid_offline_corpds_ds_profile"
TEST_AUDIENCE = "test_audience"
TEST_RESOURCE = "test_resource"

TEST_HEADERS = {
"Content-type": "application/json",
"X-Client-Version": "0.14.1",
Expand Down Expand Up @@ -137,6 +140,48 @@ def test_token_200_without_optional_fields(rest_mock: Mocker) -> None:
)


def test_token_with_optional_oauth_params(rest_mock: Mocker) -> None:
mock_request = rest_mock.post(
f"{TEST_URI}v1/oauth/tokens",
json={
"access_token": TEST_TOKEN,
"token_type": "Bearer",
"expires_in": 86400,
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
},
status_code=200,
request_headers=OAUTH_TEST_HEADERS,
)
assert (
RestCatalog(
"rest", uri=TEST_URI, credential=TEST_CREDENTIALS, audience=TEST_AUDIENCE, resource=TEST_RESOURCE
)._session.headers["Authorization"]
== f"Bearer {TEST_TOKEN}"
)
assert TEST_AUDIENCE in mock_request.last_request.text
assert TEST_RESOURCE in mock_request.last_request.text


def test_token_with_optional_oauth_params_as_empty(rest_mock: Mocker) -> None:
mock_request = rest_mock.post(
f"{TEST_URI}v1/oauth/tokens",
json={
"access_token": TEST_TOKEN,
"token_type": "Bearer",
"expires_in": 86400,
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
},
status_code=200,
request_headers=OAUTH_TEST_HEADERS,
)
assert (
RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS, audience="", resource="")._session.headers["Authorization"]
== f"Bearer {TEST_TOKEN}"
)
assert TEST_AUDIENCE not in mock_request.last_request.text
assert TEST_RESOURCE not in mock_request.last_request.text


def test_token_with_default_scope(rest_mock: Mocker) -> None:
mock_request = rest_mock.post(
f"{TEST_URI}v1/oauth/tokens",
Expand Down

0 comments on commit 29fd42c

Please sign in to comment.