Skip to content

Commit

Permalink
feat: allow passing SSL verify to session
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Dec 18, 2024
1 parent bac8b60 commit e6f93d3
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ Changelog
Next
====

- Allow disabling certificate validation in GSheets (#???)

Version 1.3.3 - 2024-12-01
==========================

Expand Down
11 changes: 9 additions & 2 deletions src/shillelagh/adapters/api/gsheets/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import urllib.parse
from collections.abc import Iterator
from typing import Any, Optional, cast
from typing import Any, Optional, Union, cast

import dateutil.tz
from google.auth.transport.requests import AuthorizedSession
Expand Down Expand Up @@ -110,6 +110,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-positional-argumen
subject: Optional[str] = None,
catalog: Optional[dict[str, str]] = None,
app_default_credentials: bool = False,
session_verify: Optional[Union[bool, str]] = None,
):
super().__init__()
if catalog and uri in catalog:
Expand All @@ -123,6 +124,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-positional-argumen
subject,
app_default_credentials,
)
self.session_verify = session_verify

# Local data. When using DML we switch to the Google Sheets API,
# keeping a local copy of the spreadsheets data so that we can
Expand Down Expand Up @@ -200,11 +202,16 @@ def _set_metadata(self, uri: str) -> None:
_logger.warning("Could not determine sheet name!")

def _get_session(self) -> Session:
return cast(
session = cast(
Session,
AuthorizedSession(self.credentials) if self.credentials else Session(),
)

if self.session_verify is not None:
session.verify = self.session_verify

return session

def get_metadata(self) -> dict[str, Any]:
"""
Get metadata of a sheet.
Expand Down
7 changes: 5 additions & 2 deletions src/shillelagh/backends/apsw/dialects/gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import urllib.parse
from datetime import timedelta
from operator import itemgetter
from typing import Any, Optional, cast
from typing import Any, Optional, Union, cast

import requests
from google.auth.transport.requests import AuthorizedSession
Expand Down Expand Up @@ -59,7 +59,7 @@ def extract_query(url: URL) -> QueryType:
return cast(QueryType, parameters)


class APSWGSheetsDialect(APSWDialect):
class APSWGSheetsDialect(APSWDialect): # pylint: disable=too-many-instance-attributes
"""
Drop-in replacement for gsheetsdb.
Expand All @@ -86,6 +86,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-positional-argumen
catalog: Optional[dict[str, str]] = None,
list_all_sheets: bool = False,
app_default_credentials: bool = False,
session_verify: Optional[Union[bool, str]] = None,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand All @@ -97,6 +98,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-positional-argumen
self.catalog = catalog or {}
self.list_all_sheets = list_all_sheets
self.app_default_credentials = app_default_credentials
self.session_verify = session_verify

def create_connect_args(self, url: URL) -> tuple[tuple[()], dict[str, Any]]:
adapter_kwargs: dict[str, Any] = {
Expand All @@ -106,6 +108,7 @@ def create_connect_args(self, url: URL) -> tuple[tuple[()], dict[str, Any]]:
"subject": self.subject,
"catalog": self.catalog,
"app_default_credentials": self.app_default_credentials,
"session_verify": self.session_verify,
}
# parameters can be overridden via the query in the URL
adapter_kwargs.update(extract_query(url))
Expand Down
32 changes: 30 additions & 2 deletions tests/adapters/api/gsheets/adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import requests
import requests_mock
from pytest_mock import MockerFixture
from sqlalchemy import create_engine, text

from shillelagh.adapters.api.gsheets.adapter import GSheetsAPI
from shillelagh.backends.apsw.db import connect
Expand Down Expand Up @@ -191,7 +192,7 @@ def test_credentials() -> None:
mock.call("BEGIN IMMEDIATE"),
mock.call('SELECT 1 FROM "https://docs.google.com/spreadsheets/d/1"', None),
mock.call(
"CREATE VIRTUAL TABLE \"https://docs.google.com/spreadsheets/d/1\" USING GSheetsAPI('+ihodHRwczovL2RvY3MuZ29vZ2xlLmNvbS9zcHJlYWRzaGVldHMvZC8x', 'Tg==', 'Tg==', '+9oGc2VjcmV02gNYWFgw', '+hB1c2VyQGV4YW1wbGUuY29t', 'Tg==', 'Rg==')",
"CREATE VIRTUAL TABLE \"https://docs.google.com/spreadsheets/d/1\" USING GSheetsAPI('+ihodHRwczovL2RvY3MuZ29vZ2xlLmNvbS9zcHJlYWRzaGVldHMvZC8x', 'Tg==', 'Tg==', '+9oGc2VjcmV02gNYWFgw', '+hB1c2VyQGV4YW1wbGUuY29t', 'Tg==', 'Rg==', 'Tg==')",
),
mock.call('SELECT 1 FROM "https://docs.google.com/spreadsheets/d/1"', None),
],
Expand Down Expand Up @@ -224,7 +225,7 @@ def test_credentials() -> None:
mock.call("BEGIN IMMEDIATE"),
mock.call('SELECT 1 FROM "https://docs.google.com/spreadsheets/d/1"', None),
mock.call(
"CREATE VIRTUAL TABLE \"https://docs.google.com/spreadsheets/d/1\" USING GSheetsAPI('+ihodHRwczovL2RvY3MuZ29vZ2xlLmNvbS9zcHJlYWRzaGVldHMvZC8x', 'Tg==', 'Tg==', 'Tg==', 'Tg==', 'Tg==', 'VA==')",
"CREATE VIRTUAL TABLE \"https://docs.google.com/spreadsheets/d/1\" USING GSheetsAPI('+ihodHRwczovL2RvY3MuZ29vZ2xlLmNvbS9zcHJlYWRzaGVldHMvZC8x', 'Tg==', 'Tg==', 'Tg==', 'Tg==', 'Tg==', 'VA==', 'Tg==')",
),
mock.call('SELECT 1 FROM "https://docs.google.com/spreadsheets/d/1"', None),
],
Expand Down Expand Up @@ -2360,3 +2361,30 @@ def test_get_cost(mocker: MockerFixture) -> None:
)
== 3022
)


def test_session_verify(
mocker: MockerFixture,
simple_sheet_adapter: requests_mock.Adapter,
) -> None:
"""
Test setting ``verify`` in the session.
"""
session = requests.Session()
session.mount("https://", simple_sheet_adapter)
mocker.patch(
"shillelagh.adapters.api.gsheets.adapter.AuthorizedSession",
return_value=session,
)
mocker.patch(
"shillelagh.adapters.api.gsheets.adapter.get_credentials",
return_value="SECRET",
)

engine = create_engine("gsheets://", session_verify=False)
connection = engine.connect()

sql = '''SELECT * FROM "https://docs.google.com/spreadsheets/d/1/edit#gid=0"'''
connection.execute(text(sql))

assert session.verify is False
26 changes: 26 additions & 0 deletions tests/backends/apsw/dialects/gsheets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_gsheets_dialect() -> None:
"subject": None,
"catalog": {},
"app_default_credentials": False,
"session_verify": None,
},
},
"safe": True,
Expand All @@ -63,6 +64,7 @@ def test_gsheets_dialect() -> None:
"subject": "[email protected]",
"catalog": {},
"app_default_credentials": False,
"session_verify": None,
},
},
"safe": True,
Expand All @@ -88,6 +90,7 @@ def test_gsheets_dialect() -> None:
"subject": "[email protected]",
"catalog": {"public_sheet": "https://example.com/"},
"app_default_credentials": False,
"session_verify": None,
},
},
"safe": True,
Expand All @@ -111,6 +114,7 @@ def test_gsheets_dialect() -> None:
"subject": None,
"catalog": {},
"app_default_credentials": True,
"session_verify": None,
},
},
"safe": True,
Expand All @@ -121,6 +125,28 @@ def test_gsheets_dialect() -> None:
mock_dbapi_connection = mock.MagicMock()
assert dialect.get_schema_names(mock_dbapi_connection) == []

dialect = APSWGSheetsDialect(session_verify=False)
assert dialect.create_connect_args(make_url("gsheets://")) == (
(),
{
"path": ":memory:",
"adapters": ["gsheetsapi"],
"adapter_kwargs": {
"gsheetsapi": {
"access_token": None,
"service_account_file": None,
"service_account_info": None,
"subject": None,
"catalog": {},
"app_default_credentials": False,
"session_verify": False,
},
},
"safe": True,
"isolation_level": None,
},
)


def test_get_table_names(mocker: MockerFixture, requests_mock: Mocker) -> None:
"""
Expand Down

0 comments on commit e6f93d3

Please sign in to comment.