Skip to content

Commit

Permalink
[#3755] improvement(client-python): Support OAuth2TokenProvider for P…
Browse files Browse the repository at this point in the history
…ython client (#4011)

### What changes were proposed in this pull request?

* Add `OAuth2TokenProvider` and `DefaultOAuth2TokenProvider` in
`client-python`
* There are some components and tests missing because it would be a big
code change if they were also done in this PR, they will be added in the
following PRs
	- [ ] Error Handling: #4173 
	- [ ] Integration Test: #4208 
* Modify test file structure, and found issue #4136, solve it by reset
environment variable.

### Why are the changes needed?

Fix: #3755, #4136

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Add UT and tested by `./gradlew clients:client-python:unittest`

---------

Co-authored-by: TimWang <[email protected]>
  • Loading branch information
noidname01 and TimWang authored Jul 19, 2024
1 parent a6e7073 commit 213bcc9
Show file tree
Hide file tree
Showing 13 changed files with 590 additions and 12 deletions.
2 changes: 2 additions & 0 deletions clients/client-python/gravitino/auth/auth_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@
class AuthConstants:
HTTP_HEADER_AUTHORIZATION: str = "Authorization"

AUTHORIZATION_BEARER_HEADER: str = "Bearer "

AUTHORIZATION_BASIC_HEADER: str = "Basic "
133 changes: 133 additions & 0 deletions clients/client-python/gravitino/auth/default_oauth2_token_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""

import time
import json
import base64
from typing import Optional
from gravitino.auth.oauth2_token_provider import OAuth2TokenProvider
from gravitino.dto.responses.oauth2_token_response import OAuth2TokenResponse
from gravitino.dto.requests.oauth2_client_credential_request import (
OAuth2ClientCredentialRequest,
)
from gravitino.exceptions.base import GravitinoRuntimeException

CLIENT_CREDENTIALS = "client_credentials"
CREDENTIAL_SPLITTER = ":"
TOKEN_SPLITTER = "."
JWT_EXPIRE = "exp"


class DefaultOAuth2TokenProvider(OAuth2TokenProvider):
"""This class is the default implement of OAuth2TokenProvider."""

_credential: Optional[str]
_scope: Optional[str]
_path: Optional[str]
_token: Optional[str]

def __init__(
self,
uri: str = None,
credential: str = None,
scope: str = None,
path: str = None,
):
super().__init__(uri)

self._credential = credential
self._scope = scope
self._path = path

self.validate()

self._token = self._fetch_token()

def validate(self):
assert (
self._credential and self._credential.strip()
), "OAuth2TokenProvider must set credential"
assert self._scope and self._scope.strip(), "OAuth2TokenProvider must set scope"
assert self._path and self._path.strip(), "OAuth2TokenProvider must set path"

def _get_access_token(self) -> Optional[str]:

expires = self._expires_at_millis()

if expires is None:
return None

if expires > time.time() * 1000:
return self._token

self._token = self._fetch_token()
return self._token

def _parse_credential(self):
assert self._credential is not None, "Invalid credential: None"

credential_info = self._credential.split(CREDENTIAL_SPLITTER, maxsplit=1)
client_id = None
client_secret = None

if len(credential_info) == 2:
client_id, client_secret = credential_info
elif len(credential_info) == 1:
client_secret = credential_info[0]
else:
raise GravitinoRuntimeException(f"Invalid credential: {self._credential}")

return client_id, client_secret

def _fetch_token(self) -> str:

client_id, client_secret = self._parse_credential()

client_credential_request = OAuth2ClientCredentialRequest(
grant_type=CLIENT_CREDENTIALS,
client_id=client_id,
client_secret=client_secret,
scope=self._scope,
)

resp = self._client.post_form(
self._path, data=client_credential_request.to_dict()
)
oauth2_resp = OAuth2TokenResponse.from_json(resp.body, infer_missing=True)
oauth2_resp.validate()

return oauth2_resp.access_token()

def _expires_at_millis(self) -> int:
if self._token is None:
return None

parts = self._token.split(TOKEN_SPLITTER)

if len(parts) != 3:
return None

jwt = json.loads(
base64.b64decode(parts[1] + "=" * (-len(parts[1]) % 4)).decode("utf-8")
)

if JWT_EXPIRE not in jwt or not isinstance(jwt[JWT_EXPIRE], int):
return None

return jwt[JWT_EXPIRE] * 1000
75 changes: 75 additions & 0 deletions clients/client-python/gravitino/auth/oauth2_token_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""

from abc import abstractmethod
from typing import Optional

from gravitino.utils.http_client import HTTPClient
from gravitino.auth.auth_data_provider import AuthDataProvider
from gravitino.auth.auth_constants import AuthConstants


class OAuth2TokenProvider(AuthDataProvider):
"""OAuth2TokenProvider will request the access token from the authorization server and then provide
the access token for every request.
"""

# The HTTP client used to request the access token from the authorization server.
_client: HTTPClient

def __init__(self, uri: str):
self._client = HTTPClient(uri)

def has_token_data(self) -> bool:
"""Judge whether AuthDataProvider can provide token data.
Returns:
true if the AuthDataProvider can provide token data otherwise false.
"""
return True

def get_token_data(self) -> Optional[bytes]:
"""Acquire the data of token for authentication. The client will set the token data as HTTP header
Authorization directly. So the return value should ensure token data contain the token header
(eg: Bearer, Basic) if necessary.
Returns:
the token data is used for authentication.
"""
access_token = self._get_access_token()

if access_token is None:
return None

return (AuthConstants.AUTHORIZATION_BEARER_HEADER + access_token).encode(
"utf-8"
)

def close(self):
"""Closes the OAuth2TokenProvider and releases any underlying resources."""
if self._client is not None:
self._client.close()

@abstractmethod
def _get_access_token(self) -> Optional[str]:
"""Get the access token from the authorization server."""

@abstractmethod
def validate(self):
"""Validate the OAuth2TokenProvider"""
4 changes: 2 additions & 2 deletions clients/client-python/gravitino/auth/simple_auth_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import base64
import os

from .auth_constants import AuthConstants
from .auth_data_provider import AuthDataProvider
from gravitino.auth.auth_constants import AuthConstants
from gravitino.auth.auth_data_provider import AuthDataProvider


class SimpleAuthProvider(AuthDataProvider):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""

from typing import Optional
from dataclasses import dataclass


@dataclass
class OAuth2ClientCredentialRequest:

grant_type: str
client_id: Optional[str]
client_secret: str
scope: str

def to_dict(self, **kwarg):
return {k: v for k, v in self.__dict__.items() if v is not None}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""

from typing import Optional
from dataclasses import dataclass, field
from dataclasses_json import config

from gravitino.dto.responses.base_response import BaseResponse
from gravitino.auth.auth_constants import AuthConstants


@dataclass
class OAuth2TokenResponse(BaseResponse):

_access_token: str = field(metadata=config(field_name="access_token"))
_issue_token_type: Optional[str] = field(
metadata=config(field_name="issued_token_type")
)
_token_type: str = field(metadata=config(field_name="token_type"))
_expires_in: int = field(metadata=config(field_name="expires_in"))
_scope: str = field(metadata=config(field_name="scope"))
_refresh_token: Optional[str] = field(metadata=config(field_name="refresh_token"))

def validate(self):
"""Validates the response.
Raise:
IllegalArgumentException If the response is invalid, this exception is thrown.
"""
super().validate()

assert self._access_token is not None, "Invalid access token: None"
assert (
AuthConstants.AUTHORIZATION_BEARER_HEADER.strip().lower()
== self._token_type.lower()
), f'Unsupported token type: {self._token_type} (must be "bearer")'

def access_token(self) -> str:
return self._access_token
36 changes: 27 additions & 9 deletions clients/client-python/gravitino/utils/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ def json(self):


class HTTPClient:

FORMDATA_HEADER = {
"Content-Type": "application/x-www-form-urlencoded",
"Accept": "application/vnd.gravitino.v1+json",
}

JSON_HEADER = {
"Content-Type": "application/json",
"Accept": "application/vnd.gravitino.v1+json",
}

def __init__(
self,
host,
Expand Down Expand Up @@ -139,30 +150,32 @@ def _make_request(self, opener, request, timeout=None) -> Tuple[bool, Response]:

return (False, err_resp)

# pylint: disable=too-many-locals
def _request(
self,
method,
endpoint,
params=None,
json=None,
data=None,
headers=None,
timeout=None,
error_handler: ErrorHandler = None,
):
method = method.upper()
request_data = None

if headers:
self._update_headers(headers)
if data:
request_data = urlencode(data.to_dict()).encode()
self._update_headers(self.FORMDATA_HEADER)
else:
headers = {
"Content-Type": "application/json",
"Accept": "application/vnd.gravitino.v1+json",
}
self._update_headers(headers)
if json:
request_data = json.to_json().encode("utf-8")

if json:
request_data = json.to_json().encode("utf-8")
self._update_headers(self.JSON_HEADER)

if headers:
self._update_headers(headers)

opener = build_opener()
request = Request(self._build_url(endpoint, params), data=request_data)
Expand Down Expand Up @@ -213,6 +226,11 @@ def put(self, endpoint, json=None, error_handler=None, **kwargs):
"put", endpoint, json=json, error_handler=error_handler, **kwargs
)

def post_form(self, endpoint, data=None, error_handler=None, **kwargs):
return self._request(
"post", endpoint, data=data, error_handler=error_handler**kwargs
)

def close(self):
self._request("close", "/")
if self.auth_data_provider is not None:
Expand Down
3 changes: 2 additions & 1 deletion clients/client-python/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ llama-index==0.10.40
tenacity==8.3.0
cachetools==5.3.3
readerwriterlock==1.0.9
docker==7.1.0
docker==7.1.0
pyjwt[crypto]==2.8.0
Loading

0 comments on commit 213bcc9

Please sign in to comment.