forked from apache/gravitino
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_oauth2_token_provider.py
79 lines (62 loc) · 2.54 KB
/
test_oauth2_token_provider.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Copyright 2024 Datastrato Pvt Ltd.
This software is licensed under the Apache License version 2.
"""
import unittest
from unittest.mock import patch
from gravitino.auth.auth_constants import AuthConstants
from gravitino.auth.default_oauth_to_token_provider import DefaultOAuth2TokenProvider
from tests.unittests.auth import mock_base
# Port embedded in the token-endpoint URI (http://127.0.0.1:<port>) that the
# authentication tests pass to the provider. Those tests patch
# HTTPClient.post_form, so presumably no server actually needs to listen on
# this port — confirm if the provider ever opens a real connection.
OAUTH_PORT = 1082
class TestOAuth2TokenProvider(unittest.TestCase):
    """Unit tests for DefaultOAuth2TokenProvider.

    Tests that exercise authentication patch
    ``gravitino.utils.http_client.HTTPClient.post_form`` with canned responses
    from ``mock_base``, so the token request is answered from fixture data.
    """

    @staticmethod
    def _create_provider():
        """Build a provider with the full, valid argument set used by the auth tests."""
        return DefaultOAuth2TokenProvider(
            uri=f"http://127.0.0.1:{OAUTH_PORT}",
            credential="yy:xx",
            path="oauth/token",
            scope="test",
        )

    def test_provider_init_exception(self):
        """Reject construction when required arguments are missing.

        Every configuration below omits at least ``path`` (and some also
        ``credential`` and/or ``scope``); each must raise AssertionError.
        """
        # NOTE(review): if the provider validates its arguments with bare
        # `assert` statements, these checks disappear under `python -O` and
        # these expectations would no longer hold — confirm.
        incomplete_kwargs = (
            {"uri": "test"},
            {"uri": "test", "credential": "xx"},
            {"uri": "test", "credential": "xx", "scope": "test"},
        )
        for kwargs in incomplete_kwargs:
            # subTest reports every bad configuration independently instead of
            # stopping at the first one that fails to raise.
            with self.subTest(kwargs=kwargs):
                with self.assertRaises(AssertionError):
                    DefaultOAuth2TokenProvider(**kwargs)

    # TODO: Error Test
    @patch(
        "gravitino.utils.http_client.HTTPClient.post_form",
        return_value=mock_base.mock_authentication_with_error_authentication_type(),
    )
    def test_authentication_with_error_authentication_type(self, _mock_post_form):
        """Fail construction when the token response has a wrong authentication type.

        ``post_form`` returns the fixture from
        ``mock_base.mock_authentication_with_error_authentication_type()``
        (presumably a response whose token type the provider does not accept);
        the constructor must raise AssertionError.
        """
        with self.assertRaises(AssertionError):
            self._create_provider()

    @patch(
        "gravitino.utils.http_client.HTTPClient.post_form",
        return_value=mock_base.mock_authentication_with_non_jwt(),
    )
    def test_authentication_with_non_jwt(self, _mock_post_form):
        """A non-JWT access token is accepted but yields no token data.

        The provider still reports ``has_token_data()`` as True, while
        ``get_token_data()`` returns None.
        """
        token_provider = self._create_provider()

        self.assertTrue(token_provider.has_token_data())
        self.assertIsNone(token_provider.get_token_data())

    @patch(
        "gravitino.utils.http_client.HTTPClient.post_form",
        side_effect=mock_base.mock_authentication_with_jwt(),
    )
    def test_authentication_with_jwt(self, _mock_post_form):
        """Serve the newer of the two JWTs the patched endpoint hands out.

        ``post_form`` yields successive responses via ``side_effect``
        (presumably an expired token first, forcing a refresh); the bearer
        header returned by ``get_token_data()`` must carry the new token.
        """
        old_access_token, new_access_token = mock_base.mock_old_new_jwt()
        # Sanity check on the fixture: with identical tokens the final
        # assertion could not distinguish old from new.
        self.assertNotEqual(old_access_token, new_access_token)

        token_provider = self._create_provider()
        token_data = token_provider.get_token_data()

        # Fail with a clear assertion rather than AttributeError on `.decode`.
        self.assertIsNotNone(token_data)
        self.assertEqual(
            token_data.decode("utf-8"),
            AuthConstants.AUTHORIZATION_BEARER_HEADER + new_access_token,
        )