Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#3755] improvement(client-python): Support OAuth2TokenProvider for Python client #4011

Merged
merged 13 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clients/client-python/gravitino/auth/auth_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@
class AuthConstants:
    """String constants used when building the HTTP ``Authorization`` header."""

    # Name of the HTTP header that carries the authentication data.
    HTTP_HEADER_AUTHORIZATION: str = "Authorization"

    # Prefix for bearer tokens. The trailing space is part of the value, so the
    # token can be appended directly after it.
    AUTHORIZATION_BEARER_HEADER: str = "Bearer "

    # Prefix for basic credentials; trailing space included, as above.
    AUTHORIZATION_BASIC_HEADER: str = "Basic "
133 changes: 133 additions & 0 deletions clients/client-python/gravitino/auth/default_oauth2_token_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""

import time
import json
import base64
from typing import Optional
from gravitino.auth.oauth2_token_provider import OAuth2TokenProvider
from gravitino.dto.responses.oauth2_token_response import OAuth2TokenResponse
from gravitino.dto.requests.oauth2_client_credential_request import (
OAuth2ClientCredentialRequest,
)
from gravitino.exceptions.base import GravitinoRuntimeException

# Value sent as ``grant_type`` when requesting a token.
CLIENT_CREDENTIALS = "client_credentials"
# Separator between the client id and the client secret inside the credential string.
CREDENTIAL_SPLITTER = ":"
# Separator between the segments of the access token (three segments are expected).
TOKEN_SPLITTER = "."
# Key looked up in the decoded token payload to find the expiry time.
JWT_EXPIRE = "exp"


class DefaultOAuth2TokenProvider(OAuth2TokenProvider):
    """Default implementation of OAuth2TokenProvider.

    Requests an access token with the ``client_credentials`` grant by posting
    a form to ``path`` through the HTTP client created by the base class,
    keeps the token, and requests a new one once the ``exp`` value found in
    the token payload has passed.
    """

    _credential: Optional[str]
    _scope: Optional[str]
    _path: Optional[str]
    _token: Optional[str]

    def __init__(
        self,
        uri: str = None,
        credential: str = None,
        scope: str = None,
        path: str = None,
    ):
        """Create the provider and fetch the first token immediately.

        Args:
            uri: passed to the base class, which builds the HTTP client from it.
            credential: either ``"client_id:client_secret"`` or just the secret.
            scope: scope sent with the token request.
            path: endpoint the token request is posted to.

        Raises:
            AssertionError: if credential, scope or path is missing or blank.
        """
        super().__init__(uri)

        self._credential = credential
        self._scope = scope
        self._path = path

        self.validate()

        self._token = self._fetch_token()

    @staticmethod
    def _require_non_blank(value: Optional[str], message: str) -> None:
        """Raise AssertionError(message) when value is None, empty or blank."""
        # Raised explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if not (value and value.strip()):
            raise AssertionError(message)

    def validate(self):
        """Check that credential, scope and path are all set and non-blank.

        Raises:
            AssertionError: naming the first setting that is missing.
        """
        self._require_non_blank(
            self._credential, "OAuth2TokenProvider must set credential"
        )
        self._require_non_blank(self._scope, "OAuth2TokenProvider must set scope")
        self._require_non_blank(self._path, "OAuth2TokenProvider must set path")

    def _get_access_token(self) -> Optional[str]:
        """Return the current token, fetching a new one if it has expired.

        Returns:
            The token, or None when no expiry time can be read from the
            stored token.
        """
        expires = self._expires_at_millis()

        if expires is None:
            return None

        # ``expires`` is in milliseconds, time.time() in seconds.
        if expires > time.time() * 1000:
            return self._token

        self._token = self._fetch_token()
        return self._token

    def _parse_credential(self):
        """Split the credential into ``(client_id, client_secret)``.

        Only the first separator splits, so the secret may itself contain
        the separator. Without a separator the whole string is the secret
        and the client id is None.

        Raises:
            AssertionError: if the credential is None.
        """
        if self._credential is None:
            raise AssertionError("Invalid credential: None")

        head, separator, tail = self._credential.partition(CREDENTIAL_SPLITTER)
        if separator:
            return head, tail
        return None, head

    def _fetch_token(self) -> str:
        """Post a client-credentials request and return the access token.

        The response body is parsed into an OAuth2TokenResponse and validated
        before the token is returned.
        """
        client_id, client_secret = self._parse_credential()

        client_credential_request = OAuth2ClientCredentialRequest(
            grant_type=CLIENT_CREDENTIALS,
            client_id=client_id,
            client_secret=client_secret,
            scope=self._scope,
        )

        # Pass the request object itself: HTTPClient._request calls
        # ``data.to_dict()`` before url-encoding, so handing it an already
        # converted dict would fail with AttributeError.
        resp = self._client.post_form(self._path, data=client_credential_request)
        oauth2_resp = OAuth2TokenResponse.from_json(resp.body, infer_missing=True)
        oauth2_resp.validate()

        return oauth2_resp.access_token()

    def _expires_at_millis(self) -> Optional[int]:
        """Read the expiry time of the stored token, in milliseconds.

        Returns:
            ``exp * 1000``, or None when there is no token, the token does not
            have three segments, the payload cannot be decoded as a JSON
            object, or ``exp`` is absent or not an integer.
        """
        if self._token is None:
            return None

        parts = self._token.split(TOKEN_SPLITTER)

        if len(parts) != 3:
            return None

        payload = parts[1]
        try:
            # JWT segments use the URL-safe base64 alphabet without padding;
            # the standard decoder would silently drop '-' and '_' and
            # corrupt the payload, so use the URL-safe decoder and re-pad.
            claims = json.loads(
                base64.urlsafe_b64decode(payload + "=" * (-len(payload) % 4)).decode(
                    "utf-8"
                )
            )
        except ValueError:
            # binascii.Error, UnicodeDecodeError and JSONDecodeError are all
            # ValueError subclasses: treat an unreadable payload like a token
            # with no expiry information.
            return None

        if not isinstance(claims, dict):
            return None

        expire = claims.get(JWT_EXPIRE)
        if not isinstance(expire, int):
            return None

        return expire * 1000
75 changes: 75 additions & 0 deletions clients/client-python/gravitino/auth/oauth2_token_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""

from abc import abstractmethod
from typing import Optional

from gravitino.utils.http_client import HTTPClient
from gravitino.auth.auth_data_provider import AuthDataProvider
from gravitino.auth.auth_constants import AuthConstants


class OAuth2TokenProvider(AuthDataProvider):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems these classes do not have docs; please add docs for the new classes that have docs in Java.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

"""OAuth2TokenProvider will request the access token from the authorization server and then provide
the access token for every request.
"""

# The HTTP client used to request the access token from the authorization server.
_client: HTTPClient

def __init__(self, uri: str):
    """Create the HTTP client for the given URI and keep it on the instance."""
    http_client = HTTPClient(uri)
    self._client = http_client

def has_token_data(self) -> bool:
    """Report whether this provider can supply token data.

    Returns:
        Always True for this provider.
    """
    can_provide = True
    return can_provide

def get_token_data(self) -> Optional[bytes]:
    """Build the value for the HTTP Authorization header.

    The result already contains the bearer prefix, so the caller can set it
    as the header value without further changes.

    Returns:
        The UTF-8 encoded bearer prefix followed by the access token, or None
        when no access token is available.
    """
    token = self._get_access_token()

    if token is not None:
        header_value = AuthConstants.AUTHORIZATION_BEARER_HEADER + token
        return header_value.encode("utf-8")

    return None

def close(self):
    """Close the underlying HTTP client, if there is one."""
    client = self._client
    if client is None:
        return
    client.close()

@abstractmethod
def _get_access_token(self) -> Optional[str]:
    """Return the access token to send, or None when none is available.

    Subclasses implement how the token is obtained from the authorization
    server.
    """

@abstractmethod
def validate(self):
    """Check that the provider is configured correctly.

    Subclasses decide which settings are required.
    """
4 changes: 2 additions & 2 deletions clients/client-python/gravitino/auth/simple_auth_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import base64
import os

from .auth_constants import AuthConstants
from .auth_data_provider import AuthDataProvider
from gravitino.auth.auth_constants import AuthConstants
from gravitino.auth.auth_data_provider import AuthDataProvider


class SimpleAuthProvider(AuthDataProvider):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""

from typing import Optional
from dataclasses import dataclass


@dataclass
class OAuth2ClientCredentialRequest:
    """Form fields of a client-credentials token request."""

    grant_type: str
    client_id: Optional[str]
    client_secret: str
    scope: str

    def to_dict(self, **kwarg):
        """Return the fields as a dict, leaving out those that are None.

        The extra keyword arguments are accepted but not used.
        """
        present = {}
        for name, value in vars(self).items():
            if value is None:
                continue
            present[name] = value
        return present
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""

from typing import Optional
from dataclasses import dataclass, field
from dataclasses_json import config

from gravitino.dto.responses.base_response import BaseResponse
from gravitino.auth.auth_constants import AuthConstants


@dataclass
class OAuth2TokenResponse(BaseResponse):
    """Response returned for an OAuth2 token request.

    Each field is mapped to its JSON name through the dataclasses_json
    ``field_name`` config.
    """

    _access_token: str = field(metadata=config(field_name="access_token"))
    _issue_token_type: Optional[str] = field(
        metadata=config(field_name="issued_token_type")
    )
    _token_type: str = field(metadata=config(field_name="token_type"))
    _expires_in: int = field(metadata=config(field_name="expires_in"))
    _scope: str = field(metadata=config(field_name="scope"))
    _refresh_token: Optional[str] = field(metadata=config(field_name="refresh_token"))

    def validate(self):
        """Validate the response.

        Runs the base-class validation first, then checks that an access
        token is present and that the token type is "bearer" (compared
        case-insensitively).

        Raises:
            AssertionError: if the access token is None, or the token type is
                missing or not "bearer".
        """
        super().validate()

        # Raised explicitly rather than via ``assert`` so the checks still run
        # under -O; the exception type is unchanged.
        if self._access_token is None:
            raise AssertionError("Invalid access token: None")

        expected_type = AuthConstants.AUTHORIZATION_BEARER_HEADER.strip().lower()
        # The token type may be None when the field is missing from the JSON
        # (the response is parsed with infer_missing=True); calling .lower()
        # on it would raise AttributeError instead of the intended failure.
        actual_type = (
            self._token_type.lower() if isinstance(self._token_type, str) else None
        )
        if expected_type != actual_type:
            raise AssertionError(
                f'Unsupported token type: {self._token_type} (must be "bearer")'
            )

    def access_token(self) -> str:
        """Return the access token carried by the response."""
        return self._access_token
36 changes: 27 additions & 9 deletions clients/client-python/gravitino/utils/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ def json(self):


class HTTPClient:

# Headers applied by _request when a form payload (``data``) is sent.
FORMDATA_HEADER = {
    "Content-Type": "application/x-www-form-urlencoded",
    "Accept": "application/vnd.gravitino.v1+json",
}

# Headers applied by _request for JSON requests.
# NOTE(review): the diff shown here lost its indentation, so confirm the exact
# branch in _request that applies these.
JSON_HEADER = {
    "Content-Type": "application/json",
    "Accept": "application/vnd.gravitino.v1+json",
}

def __init__(
self,
host,
Expand Down Expand Up @@ -139,30 +150,32 @@ def _make_request(self, opener, request, timeout=None) -> Tuple[bool, Response]:

return (False, err_resp)

# pylint: disable=too-many-locals
def _request(
self,
method,
endpoint,
params=None,
json=None,
data=None,
headers=None,
timeout=None,
error_handler: ErrorHandler = None,
):
method = method.upper()
request_data = None

if headers:
self._update_headers(headers)
if data:
request_data = urlencode(data.to_dict()).encode()
self._update_headers(self.FORMDATA_HEADER)
else:
headers = {
"Content-Type": "application/json",
"Accept": "application/vnd.gravitino.v1+json",
}
self._update_headers(headers)
if json:
request_data = json.to_json().encode("utf-8")

if json:
request_data = json.to_json().encode("utf-8")
self._update_headers(self.JSON_HEADER)

if headers:
self._update_headers(headers)

opener = build_opener()
request = Request(self._build_url(endpoint, params), data=request_data)
Expand Down Expand Up @@ -213,6 +226,11 @@ def put(self, endpoint, json=None, error_handler=None, **kwargs):
"put", endpoint, json=json, error_handler=error_handler, **kwargs
)

def post_form(self, endpoint, data=None, error_handler=None, **kwargs):
    """Send a POST request carrying a form payload.

    Args:
        endpoint: endpoint forwarded to ``_request``.
        data: form payload, forwarded as the ``data`` argument of ``_request``.
        error_handler: forwarded to ``_request``.
        **kwargs: any further keyword arguments, forwarded to ``_request``.

    Returns:
        Whatever ``_request`` returns.
    """
    # The comma before ``**kwargs`` matters: without it the expression is
    # parsed as ``error_handler ** kwargs`` (a power operation) and every
    # call fails with TypeError.
    return self._request(
        "post", endpoint, data=data, error_handler=error_handler, **kwargs
    )

def close(self):
self._request("close", "/")
if self.auth_data_provider is not None:
Expand Down
3 changes: 2 additions & 1 deletion clients/client-python/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ llama-index==0.10.40
tenacity==8.3.0
cachetools==5.3.3
readerwriterlock==1.0.9
docker==7.1.0
docker==7.1.0
pyjwt[crypto]==2.8.0
Loading
Loading