Skip to content

Commit

Permalink
feat(embedded): aud claim and type for guest token (#18651)
Browse files Browse the repository at this point in the history
* add aud claim and type for guest token

* update test

* lint

* make jwt audience configurable

* lint

* Apply suggestions from code review

Co-authored-by: David Aaron Suddjian <[email protected]>

* verify aud

* add tests for aud and type claim

Co-authored-by: David Aaron Suddjian <[email protected]>
  • Loading branch information
Lily Kuang and suddjian authored Feb 14, 2022
1 parent 4001165 commit e6ea197
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 3 deletions.
1 change: 1 addition & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,7 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument
GUEST_TOKEN_JWT_ALGO = "HS256"
GUEST_TOKEN_HEADER_NAME = "X-GuestToken"
GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes
GUEST_TOKEN_JWT_AUDIENCE = None

# A SQL dataset health check. Note if enabled it is strongly advised that the callable
# be memoized to aid with performance, i.e.,
Expand Down
9 changes: 8 additions & 1 deletion superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
GuestUser,
)
from superset.utils.core import DatasourceName, RowLevelSecurityFilterType
from superset.utils.urls import get_url_host

if TYPE_CHECKING:
from superset.common.query_context import QueryContext
Expand Down Expand Up @@ -1308,6 +1309,7 @@ def create_guest_access_token(
secret = current_app.config["GUEST_TOKEN_JWT_SECRET"]
algo = current_app.config["GUEST_TOKEN_JWT_ALGO"]
exp_seconds = current_app.config["GUEST_TOKEN_JWT_EXP_SECONDS"]
aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host()

# calculate expiration time
now = self._get_current_epoch_time()
Expand All @@ -1319,6 +1321,8 @@ def create_guest_access_token(
# standard jwt claims:
"iat": now, # issued at
"exp": exp, # expiration time
"aud": aud,
"type": "guest",
}
token = jwt.encode(claims, secret, algorithm=algo)
return token
Expand All @@ -1344,6 +1348,8 @@ def get_guest_user_from_request(self, req: Request) -> Optional[GuestUser]:
raise ValueError("Guest token does not contain a resources claim")
if token.get("rls_rules") is None:
raise ValueError("Guest token does not contain an rls_rules claim")
if token.get("type") != "guest":
raise ValueError("This is not a guest token.")
except Exception: # pylint: disable=broad-except
# The login manager will handle sending 401s.
# We don't need to send a special error message.
Expand All @@ -1366,7 +1372,8 @@ def parse_jwt_guest_token(raw_token: str) -> Dict[str, Any]:
"""
secret = current_app.config["GUEST_TOKEN_JWT_SECRET"]
algo = current_app.config["GUEST_TOKEN_JWT_ALGO"]
return jwt.decode(raw_token, secret, algorithms=[algo])
aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host()
return jwt.decode(raw_token, secret, algorithms=[algo], audience=aud)

@staticmethod
def is_guest_user(user: Optional[Any] = None) -> bool:
Expand Down
5 changes: 4 additions & 1 deletion tests/integration_tests/security/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from tests.integration_tests.base_tests import SupersetTestCase
from flask_wtf.csrf import generate_csrf
from superset.utils.urls import get_url_host


class TestSecurityCsrfApi(SupersetTestCase):
Expand Down Expand Up @@ -90,6 +91,8 @@ def test_post_guest_token_authorized(self):

self.assert200(response)
token = json.loads(response.data)["token"]
decoded_token = jwt.decode(token, self.app.config["GUEST_TOKEN_JWT_SECRET"])
decoded_token = jwt.decode(
token, self.app.config["GUEST_TOKEN_JWT_SECRET"], audience=get_url_host()
)
self.assertEqual(user, decoded_token["user"])
self.assertEqual(resource, decoded_token["resources"][0])
60 changes: 59 additions & 1 deletion tests/integration_tests/security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
get_example_default_schema,
)
from superset.utils.database import get_example_database
from superset.utils.urls import get_url_host
from superset.views.access_requests import AccessRequestsModelView

from .base_tests import SupersetTestCase
Expand Down Expand Up @@ -1177,17 +1178,20 @@ def test_create_guest_access_token(self, get_time_mock):
resources = [{"some": "resource"}]
rls = [{"dataset": 1, "clause": "access = 1"}]
token = security_manager.create_guest_access_token(user, resources, rls)

aud = get_url_host()
# unfortunately we cannot mock time in the jwt lib
decoded_token = jwt.decode(
token,
self.app.config["GUEST_TOKEN_JWT_SECRET"],
algorithms=[self.app.config["GUEST_TOKEN_JWT_ALGO"]],
audience=aud,
)

self.assertEqual(user, decoded_token["user"])
self.assertEqual(resources, decoded_token["resources"])
self.assertEqual(now, decoded_token["iat"])
self.assertEqual(aud, decoded_token["aud"])
self.assertEqual("guest", decoded_token["type"])
self.assertEqual(
now + (self.app.config["GUEST_TOKEN_JWT_EXP_SECONDS"] * 1000),
decoded_token["exp"],
Expand Down Expand Up @@ -1241,3 +1245,57 @@ def test_get_guest_user_no_resource(self):
self.assertRaisesRegex(
ValueError, "Guest token does not contain a resources claim"
)

def test_get_guest_user_not_guest_type(self):
now = time.time()
user = {"username": "test_guest"}
resources = [{"some": "resource"}]
aud = get_url_host()

claims = {
"user": user,
"resources": resources,
"rls_rules": [],
# standard jwt claims:
"aud": aud,
"iat": now, # issued at
"type": "not_guest",
}
token = jwt.encode(
claims,
self.app.config["GUEST_TOKEN_JWT_SECRET"],
algorithm=self.app.config["GUEST_TOKEN_JWT_ALGO"],
)
fake_request = FakeRequest()
fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token
guest_user = security_manager.get_guest_user_from_request(fake_request)

self.assertIsNone(guest_user)
self.assertRaisesRegex(ValueError, "This is not a guest token.")

def test_get_guest_user_bad_audience(self):
now = time.time()
user = {"username": "test_guest"}
resources = [{"some": "resource"}]
aud = get_url_host()

claims = {
"user": user,
"resources": resources,
"rls_rules": [],
# standard jwt claims:
"aud": "bad_audience",
"iat": now, # issued at
"type": "guest",
}
token = jwt.encode(
claims,
self.app.config["GUEST_TOKEN_JWT_SECRET"],
algorithm=self.app.config["GUEST_TOKEN_JWT_ALGO"],
)
fake_request = FakeRequest()
fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token
guest_user = security_manager.get_guest_user_from_request(fake_request)

self.assertRaisesRegex(jwt.exceptions.InvalidAudienceError, "Invalid audience")
self.assertIsNone(guest_user)

0 comments on commit e6ea197

Please sign in to comment.