-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7e1c4e9
commit fdadf84
Showing
6 changed files
with
333 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,15 +18,22 @@ | |
# pylint: disable=import-outside-toplevel, invalid-name, line-too-long | ||
|
||
import json | ||
from typing import TYPE_CHECKING | ||
from urllib.parse import parse_qs, urlparse | ||
|
||
import jwt | ||
import pandas as pd | ||
import pytest | ||
from pytest_mock import MockFixture | ||
from sqlalchemy.engine.url import make_url | ||
|
||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType | ||
from superset.exceptions import SupersetException | ||
from superset.sql_parse import Table | ||
|
||
if TYPE_CHECKING: | ||
from superset.db_engine_specs.base import OAuth2State | ||
|
||
|
||
class ProgrammingError(Exception): | ||
""" | ||
|
@@ -399,3 +406,223 @@ def test_upload_existing(mocker: MockFixture) -> None: | |
mocker.call().json(), | ||
] | ||
) | ||
|
||
|
||
def test_get_url_for_impersonation_username(mocker: MockFixture) -> None: | ||
""" | ||
Test passing a username to `get_url_for_impersonation`. | ||
""" | ||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec | ||
|
||
user = mocker.MagicMock() | ||
user.email = "[email protected]" | ||
mocker.patch( | ||
"superset.db_engine_specs.gsheets.security_manager.find_user", | ||
return_value=user, | ||
) | ||
|
||
assert GSheetsEngineSpec.get_url_for_impersonation( | ||
url=make_url("gsheets://"), | ||
impersonate_user=True, | ||
username="alice", | ||
access_token=None, | ||
) == make_url("gsheets://?subject=alice%40example.org") | ||
|
||
|
||
def test_get_url_for_impersonation_access_token() -> None: | ||
""" | ||
Test passing an access token to `get_url_for_impersonation`. | ||
""" | ||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec | ||
|
||
assert GSheetsEngineSpec.get_url_for_impersonation( | ||
url=make_url("gsheets://"), | ||
impersonate_user=True, | ||
username=None, | ||
access_token="access-token", | ||
) == make_url("gsheets://?access_token=access-token") | ||
|
||
|
||
def test_is_oauth2_enabled_no_config(mocker: MockFixture) -> None: | ||
""" | ||
Test `is_oauth2_enabled` when OAuth2 is not configured. | ||
""" | ||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec | ||
|
||
mocker.patch( | ||
"superset.db_engine_specs.gsheets.current_app.config", | ||
new={"DATABASE_OAUTH2_CREDENTIALS": {}}, | ||
) | ||
|
||
assert GSheetsEngineSpec.is_oauth2_enabled() is False | ||
|
||
|
||
def test_is_oauth2_enabled_config(mocker: MockFixture) -> None: | ||
""" | ||
Test `is_oauth2_enabled` when OAuth2 is configured. | ||
""" | ||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec | ||
|
||
mocker.patch( | ||
"superset.db_engine_specs.gsheets.current_app.config", | ||
new={ | ||
"DATABASE_OAUTH2_CREDENTIALS": { | ||
"Google Sheets": { | ||
"CLIENT_ID": "XXX.apps.googleusercontent.com", | ||
"CLIENT_SECRET": "GOCSPX-YYY", | ||
}, | ||
} | ||
}, | ||
) | ||
|
||
assert GSheetsEngineSpec.is_oauth2_enabled() is True | ||
|
||
|
||
def test_get_oauth2_authorization_uri(mocker: MockFixture) -> None: | ||
""" | ||
Test `get_oauth2_authorization_uri`. | ||
""" | ||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec | ||
|
||
mocker.patch( | ||
"superset.db_engine_specs.gsheets.current_app.config", | ||
new={ | ||
"DATABASE_OAUTH2_CREDENTIALS": { | ||
"Google Sheets": { | ||
"CLIENT_ID": "XXX.apps.googleusercontent.com", | ||
"CLIENT_SECRET": "GOCSPX-YYY", | ||
}, | ||
}, | ||
"SECRET_KEY": "not-a-secret", | ||
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", | ||
}, | ||
) | ||
|
||
state: OAuth2State = { | ||
"database_id": 1, | ||
"user_id": 1, | ||
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", | ||
"tab_id": "1234", | ||
} | ||
|
||
url = GSheetsEngineSpec.get_oauth2_authorization_uri(state) | ||
parsed = urlparse(url) | ||
assert parsed.netloc == "accounts.google.com" | ||
assert parsed.path == "/o/oauth2/v2/auth" | ||
|
||
query = parse_qs(parsed.query) | ||
assert query["scope"][0] == ( | ||
"https://www.googleapis.com/auth/drive.readonly " | ||
"https://www.googleapis.com/auth/spreadsheets " | ||
"https://spreadsheets.google.com/feeds" | ||
) | ||
encoded_state = query["state"][0].replace("%2E", ".") | ||
assert jwt.decode(encoded_state, "not-a-secret", ["HS256"]) == state | ||
|
||
|
||
def test_get_oauth2_token(mocker: MockFixture) -> None: | ||
""" | ||
Test `get_oauth2_token`. | ||
""" | ||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec | ||
|
||
http = mocker.patch("superset.db_engine_specs.gsheets.http") | ||
http.request().data.decode.return_value = json.dumps( | ||
{ | ||
"access_token": "access-token", | ||
"expires_in": 3600, | ||
"scope": "scope", | ||
"token_type": "Bearer", | ||
"refresh_token": "refresh-token", | ||
} | ||
) | ||
|
||
mocker.patch( | ||
"superset.db_engine_specs.gsheets.current_app.config", | ||
new={ | ||
"DATABASE_OAUTH2_CREDENTIALS": { | ||
"Google Sheets": { | ||
"CLIENT_ID": "XXX.apps.googleusercontent.com", | ||
"CLIENT_SECRET": "GOCSPX-YYY", | ||
}, | ||
}, | ||
"SECRET_KEY": "not-a-secret", | ||
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", | ||
}, | ||
) | ||
|
||
state: OAuth2State = { | ||
"database_id": 1, | ||
"user_id": 1, | ||
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", | ||
"tab_id": "1234", | ||
} | ||
|
||
assert GSheetsEngineSpec.get_oauth2_token("code", state) == { | ||
"access_token": "access-token", | ||
"expires_in": 3600, | ||
"scope": "scope", | ||
"token_type": "Bearer", | ||
"refresh_token": "refresh-token", | ||
} | ||
http.request.assert_called_with( | ||
"POST", | ||
"https://oauth2.googleapis.com/token", | ||
fields={ | ||
"code": "code", | ||
"client_id": "XXX.apps.googleusercontent.com", | ||
"client_secret": "GOCSPX-YYY", | ||
"redirect_uri": "http://localhost:8088/api/v1/oauth2/", | ||
"grant_type": "authorization_code", | ||
}, | ||
) | ||
|
||
|
||
def test_get_oauth2_fresh_token(mocker: MockFixture) -> None: | ||
""" | ||
Test `get_oauth2_token`. | ||
""" | ||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec | ||
|
||
http = mocker.patch("superset.db_engine_specs.gsheets.http") | ||
http.request().data.decode.return_value = json.dumps( | ||
{ | ||
"access_token": "access-token", | ||
"expires_in": 3600, | ||
"scope": "scope", | ||
"token_type": "Bearer", | ||
"refresh_token": "refresh-token", | ||
} | ||
) | ||
|
||
mocker.patch( | ||
"superset.db_engine_specs.gsheets.current_app.config", | ||
new={ | ||
"DATABASE_OAUTH2_CREDENTIALS": { | ||
"Google Sheets": { | ||
"CLIENT_ID": "XXX.apps.googleusercontent.com", | ||
"CLIENT_SECRET": "GOCSPX-YYY", | ||
}, | ||
}, | ||
"SECRET_KEY": "not-a-secret", | ||
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", | ||
}, | ||
) | ||
|
||
assert GSheetsEngineSpec.get_oauth2_fresh_token("refresh-token") == { | ||
"access_token": "access-token", | ||
"expires_in": 3600, | ||
"scope": "scope", | ||
"token_type": "Bearer", | ||
"refresh_token": "refresh-token", | ||
} | ||
http.request.assert_called_with( | ||
"POST", | ||
"https://oauth2.googleapis.com/token", | ||
fields={ | ||
"client_id": "XXX.apps.googleusercontent.com", | ||
"client_secret": "GOCSPX-YYY", | ||
"refresh_token": "refresh-token", | ||
"grant_type": "refresh_token", | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# pylint: disable=invalid-name, disallowed-name | ||
|
||
from datetime import datetime | ||
|
||
from freezegun import freeze_time | ||
from pytest_mock import MockerFixture | ||
|
||
from superset.utils.oauth2 import get_oauth2_access_token | ||
|
||
|
||
def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None: | ||
""" | ||
Test `get_oauth2_access_token` when there's no token. | ||
""" | ||
db = mocker.patch("superset.utils.oauth2.db") | ||
db_engine_spec = mocker.MagicMock() | ||
db.session.query().filter_by().one_or_none.return_value = None | ||
|
||
assert get_oauth2_access_token(1, 1, db_engine_spec) is None | ||
|
||
|
||
def test_get_oauth2_access_token_base_token_valid(mocker: MockerFixture) -> None: | ||
""" | ||
Test `get_oauth2_access_token` when the token is valid. | ||
""" | ||
db = mocker.patch("superset.utils.oauth2.db") | ||
db_engine_spec = mocker.MagicMock() | ||
token = mocker.MagicMock() | ||
token.access_token = "access-token" | ||
token.access_token_expiration = datetime(2024, 1, 2) | ||
db.session.query().filter_by().one_or_none.return_value = token | ||
|
||
with freeze_time("2024-01-01"): | ||
assert get_oauth2_access_token(1, 1, db_engine_spec) == "access-token" | ||
|
||
|
||
def test_get_oauth2_access_token_base_refresh(mocker: MockerFixture) -> None: | ||
""" | ||
Test `get_oauth2_access_token` when the token needs to be refreshed. | ||
""" | ||
db = mocker.patch("superset.utils.oauth2.db") | ||
db_engine_spec = mocker.MagicMock() | ||
db_engine_spec.get_oauth2_fresh_token.return_value = { | ||
"access_token": "new-token", | ||
"expires_in": 3600, | ||
} | ||
token = mocker.MagicMock() | ||
token.access_token = "access-token" | ||
token.access_token_expiration = datetime(2024, 1, 1) | ||
token.refresh_token = "refresh-token" | ||
db.session.query().filter_by().one_or_none.return_value = token | ||
|
||
with freeze_time("2024-01-02"): | ||
assert get_oauth2_access_token(1, 1, db_engine_spec) == "new-token" | ||
|
||
# check that token was updated | ||
assert token.access_token == "new-token" | ||
assert token.access_token_expiration == datetime(2024, 1, 2, 1) | ||
db.session.add.assert_called_with(token) | ||
|
||
|
||
def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None: | ||
""" | ||
Test `get_oauth2_access_token` when token is expired and there's no refresh. | ||
""" | ||
db = mocker.patch("superset.utils.oauth2.db") | ||
db_engine_spec = mocker.MagicMock() | ||
token = mocker.MagicMock() | ||
token.access_token = "access-token" | ||
token.access_token_expiration = datetime(2024, 1, 1) | ||
token.refresh_token = None | ||
db.session.query().filter_by().one_or_none.return_value = token | ||
|
||
with freeze_time("2024-01-02"): | ||
assert get_oauth2_access_token(1, 1, db_engine_spec) is None | ||
|
||
# check that token was deleted | ||
db.session.delete.assert_called_with(token) |