Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Mar 26, 2024
1 parent 7e1c4e9 commit fdadf84
Show file tree
Hide file tree
Showing 6 changed files with 333 additions and 1 deletion.
3 changes: 3 additions & 0 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,9 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
# errors. This is particularly important for database OAuth2, see SIP-85.
raise ex
except Exception as ex: # pylint: disable=broad-except
# TODO (betodealmeida): review exception handling while querying the external
# database. Ideally we'd expect and handle external database error, but everything else / the
# default should be to let things bubble up.
df = pd.DataFrame()
status = QueryStatus.FAILED
logger.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,13 @@ def upgrade():
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"idx_user_id_database_id",
"database_user_oauth2_tokens",
["user_id", "database_id"],
)


def downgrade():
op.drop_index("idx_user_id_database_id", table_name="database_user_oauth2_tokens")
op.drop_table("database_user_oauth2_tokens")
1 change: 1 addition & 0 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,7 @@ class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable):
"""

__tablename__ = "database_user_oauth2_tokens"
__table_args__ = (sqla.Index("idx_user_id_database_id", "user_id", "database_id"),)

id = Column(Integer, primary_key=True)

Expand Down
2 changes: 1 addition & 1 deletion superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_oauth2_access_token(
if token is None:
return None

if token.access_token and token.access_token_expiration < datetime.now():
if token.access_token and datetime.now() < token.access_token_expiration:
return token.access_token

if token.refresh_token:
Expand Down
227 changes: 227 additions & 0 deletions tests/unit_tests/db_engine_specs/test_gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,22 @@
# pylint: disable=import-outside-toplevel, invalid-name, line-too-long

import json
from typing import TYPE_CHECKING
from urllib.parse import parse_qs, urlparse

import jwt
import pandas as pd
import pytest
from pytest_mock import MockFixture
from sqlalchemy.engine.url import make_url

from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.sql_parse import Table

if TYPE_CHECKING:
from superset.db_engine_specs.base import OAuth2State


class ProgrammingError(Exception):
"""
Expand Down Expand Up @@ -399,3 +406,223 @@ def test_upload_existing(mocker: MockFixture) -> None:
mocker.call().json(),
]
)


def test_get_url_for_impersonation_username(mocker: MockFixture) -> None:
"""
Test passing a username to `get_url_for_impersonation`.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec

user = mocker.MagicMock()
user.email = "[email protected]"
mocker.patch(
"superset.db_engine_specs.gsheets.security_manager.find_user",
return_value=user,
)

assert GSheetsEngineSpec.get_url_for_impersonation(
url=make_url("gsheets://"),
impersonate_user=True,
username="alice",
access_token=None,
) == make_url("gsheets://?subject=alice%40example.org")


def test_get_url_for_impersonation_access_token() -> None:
"""
Test passing an access token to `get_url_for_impersonation`.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec

assert GSheetsEngineSpec.get_url_for_impersonation(
url=make_url("gsheets://"),
impersonate_user=True,
username=None,
access_token="access-token",
) == make_url("gsheets://?access_token=access-token")


def test_is_oauth2_enabled_no_config(mocker: MockFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is not configured.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec

mocker.patch(
"superset.db_engine_specs.gsheets.current_app.config",
new={"DATABASE_OAUTH2_CREDENTIALS": {}},
)

assert GSheetsEngineSpec.is_oauth2_enabled() is False


def test_is_oauth2_enabled_config(mocker: MockFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is configured.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec

mocker.patch(
"superset.db_engine_specs.gsheets.current_app.config",
new={
"DATABASE_OAUTH2_CREDENTIALS": {
"Google Sheets": {
"CLIENT_ID": "XXX.apps.googleusercontent.com",
"CLIENT_SECRET": "GOCSPX-YYY",
},
}
},
)

assert GSheetsEngineSpec.is_oauth2_enabled() is True


def test_get_oauth2_authorization_uri(mocker: MockFixture) -> None:
"""
Test `get_oauth2_authorization_uri`.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec

mocker.patch(
"superset.db_engine_specs.gsheets.current_app.config",
new={
"DATABASE_OAUTH2_CREDENTIALS": {
"Google Sheets": {
"CLIENT_ID": "XXX.apps.googleusercontent.com",
"CLIENT_SECRET": "GOCSPX-YYY",
},
},
"SECRET_KEY": "not-a-secret",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)

state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"tab_id": "1234",
}

url = GSheetsEngineSpec.get_oauth2_authorization_uri(state)
parsed = urlparse(url)
assert parsed.netloc == "accounts.google.com"
assert parsed.path == "/o/oauth2/v2/auth"

query = parse_qs(parsed.query)
assert query["scope"][0] == (
"https://www.googleapis.com/auth/drive.readonly "
"https://www.googleapis.com/auth/spreadsheets "
"https://spreadsheets.google.com/feeds"
)
encoded_state = query["state"][0].replace("%2E", ".")
assert jwt.decode(encoded_state, "not-a-secret", ["HS256"]) == state


def test_get_oauth2_token(mocker: MockFixture) -> None:
"""
Test `get_oauth2_token`.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec

http = mocker.patch("superset.db_engine_specs.gsheets.http")
http.request().data.decode.return_value = json.dumps(
{
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
)

mocker.patch(
"superset.db_engine_specs.gsheets.current_app.config",
new={
"DATABASE_OAUTH2_CREDENTIALS": {
"Google Sheets": {
"CLIENT_ID": "XXX.apps.googleusercontent.com",
"CLIENT_SECRET": "GOCSPX-YYY",
},
},
"SECRET_KEY": "not-a-secret",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)

state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"tab_id": "1234",
}

assert GSheetsEngineSpec.get_oauth2_token("code", state) == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
http.request.assert_called_with(
"POST",
"https://oauth2.googleapis.com/token",
fields={
"code": "code",
"client_id": "XXX.apps.googleusercontent.com",
"client_secret": "GOCSPX-YYY",
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"grant_type": "authorization_code",
},
)


def test_get_oauth2_fresh_token(mocker: MockFixture) -> None:
"""
Test `get_oauth2_token`.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec

http = mocker.patch("superset.db_engine_specs.gsheets.http")
http.request().data.decode.return_value = json.dumps(
{
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
)

mocker.patch(
"superset.db_engine_specs.gsheets.current_app.config",
new={
"DATABASE_OAUTH2_CREDENTIALS": {
"Google Sheets": {
"CLIENT_ID": "XXX.apps.googleusercontent.com",
"CLIENT_SECRET": "GOCSPX-YYY",
},
},
"SECRET_KEY": "not-a-secret",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)

assert GSheetsEngineSpec.get_oauth2_fresh_token("refresh-token") == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
http.request.assert_called_with(
"POST",
"https://oauth2.googleapis.com/token",
fields={
"client_id": "XXX.apps.googleusercontent.com",
"client_secret": "GOCSPX-YYY",
"refresh_token": "refresh-token",
"grant_type": "refresh_token",
},
)
95 changes: 95 additions & 0 deletions tests/unit_tests/utils/oauth2_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name, disallowed-name

from datetime import datetime

from freezegun import freeze_time
from pytest_mock import MockerFixture

from superset.utils.oauth2 import get_oauth2_access_token


def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None:
"""
Test `get_oauth2_access_token` when there's no token.
"""
db = mocker.patch("superset.utils.oauth2.db")
db_engine_spec = mocker.MagicMock()
db.session.query().filter_by().one_or_none.return_value = None

assert get_oauth2_access_token(1, 1, db_engine_spec) is None


def test_get_oauth2_access_token_base_token_valid(mocker: MockerFixture) -> None:
"""
Test `get_oauth2_access_token` when the token is valid.
"""
db = mocker.patch("superset.utils.oauth2.db")
db_engine_spec = mocker.MagicMock()
token = mocker.MagicMock()
token.access_token = "access-token"
token.access_token_expiration = datetime(2024, 1, 2)
db.session.query().filter_by().one_or_none.return_value = token

with freeze_time("2024-01-01"):
assert get_oauth2_access_token(1, 1, db_engine_spec) == "access-token"


def test_get_oauth2_access_token_base_refresh(mocker: MockerFixture) -> None:
"""
Test `get_oauth2_access_token` when the token needs to be refreshed.
"""
db = mocker.patch("superset.utils.oauth2.db")
db_engine_spec = mocker.MagicMock()
db_engine_spec.get_oauth2_fresh_token.return_value = {
"access_token": "new-token",
"expires_in": 3600,
}
token = mocker.MagicMock()
token.access_token = "access-token"
token.access_token_expiration = datetime(2024, 1, 1)
token.refresh_token = "refresh-token"
db.session.query().filter_by().one_or_none.return_value = token

with freeze_time("2024-01-02"):
assert get_oauth2_access_token(1, 1, db_engine_spec) == "new-token"

# check that token was updated
assert token.access_token == "new-token"
assert token.access_token_expiration == datetime(2024, 1, 2, 1)
db.session.add.assert_called_with(token)


def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None:
"""
Test `get_oauth2_access_token` when token is expired and there's no refresh.
"""
db = mocker.patch("superset.utils.oauth2.db")
db_engine_spec = mocker.MagicMock()
token = mocker.MagicMock()
token.access_token = "access-token"
token.access_token_expiration = datetime(2024, 1, 1)
token.refresh_token = None
db.session.query().filter_by().one_or_none.return_value = token

with freeze_time("2024-01-02"):
assert get_oauth2_access_token(1, 1, db_engine_spec) is None

# check that token was deleted
db.session.delete.assert_called_with(token)

0 comments on commit fdadf84

Please sign in to comment.