diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 92513ebc..0cef6900 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,6 +5,8 @@ Changelog Next ==== +- Allow disabling certificate validation in GSheets (#???) + Version 1.3.3 - 2024-12-01 ========================== diff --git a/src/shillelagh/adapters/api/gsheets/adapter.py b/src/shillelagh/adapters/api/gsheets/adapter.py index 855a9005..48c4c9d6 100644 --- a/src/shillelagh/adapters/api/gsheets/adapter.py +++ b/src/shillelagh/adapters/api/gsheets/adapter.py @@ -8,7 +8,7 @@ import logging import urllib.parse from collections.abc import Iterator -from typing import Any, Optional, cast +from typing import Any, Optional, Union, cast import dateutil.tz from google.auth.transport.requests import AuthorizedSession @@ -110,6 +110,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-positional-argumen subject: Optional[str] = None, catalog: Optional[dict[str, str]] = None, app_default_credentials: bool = False, + session_verify: Optional[Union[bool, str]] = None, ): super().__init__() if catalog and uri in catalog: @@ -123,6 +124,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-positional-argumen subject, app_default_credentials, ) + self.session_verify = session_verify # Local data. When using DML we switch to the Google Sheets API, # keeping a local copy of the spreadsheets data so that we can @@ -200,11 +202,16 @@ def _set_metadata(self, uri: str) -> None: _logger.warning("Could not determine sheet name!") def _get_session(self) -> Session: - return cast( + session = cast( Session, AuthorizedSession(self.credentials) if self.credentials else Session(), ) + if self.session_verify is not None: + session.verify = self.session_verify + + return session + def get_metadata(self) -> dict[str, Any]: """ Get metadata of a sheet. diff --git a/src/shillelagh/backends/apsw/dialects/gsheets.py b/src/shillelagh/backends/apsw/dialects/gsheets.py index 32bfb12f..d46a2335 100644 --- a/src/shillelagh/backends/apsw/dialects/gsheets.py +++ b/src/shillelagh/backends/apsw/dialects/gsheets.py @@ -9,7 +9,7 @@ import urllib.parse from datetime import timedelta from operator import itemgetter -from typing import Any, Optional, cast +from typing import Any, Optional, Union, cast import requests from google.auth.transport.requests import AuthorizedSession @@ -59,7 +59,7 @@ def extract_query(url: URL) -> QueryType: return cast(QueryType, parameters) -class APSWGSheetsDialect(APSWDialect): +class APSWGSheetsDialect(APSWDialect): # pylint: disable=too-many-instance-attributes """ Drop-in replacement for gsheetsdb. @@ -86,6 +86,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-positional-argumen catalog: Optional[dict[str, str]] = None, list_all_sheets: bool = False, app_default_credentials: bool = False, + session_verify: Optional[Union[bool, str]] = None, **kwargs: Any, ): super().__init__(**kwargs) @@ -97,6 +98,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-positional-argumen self.catalog = catalog or {} self.list_all_sheets = list_all_sheets self.app_default_credentials = app_default_credentials + self.session_verify = session_verify def create_connect_args(self, url: URL) -> tuple[tuple[()], dict[str, Any]]: adapter_kwargs: dict[str, Any] = { @@ -106,6 +108,7 @@ def create_connect_args(self, url: URL) -> tuple[tuple[()], dict[str, Any]]: "subject": self.subject, "catalog": self.catalog, "app_default_credentials": self.app_default_credentials, + "session_verify": self.session_verify, } # parameters can be overridden via the query in the URL adapter_kwargs.update(extract_query(url)) diff --git a/tests/adapters/api/gsheets/adapter_test.py b/tests/adapters/api/gsheets/adapter_test.py index 815a896c..b02ad8a4 100644 --- a/tests/adapters/api/gsheets/adapter_test.py +++ b/tests/adapters/api/gsheets/adapter_test.py @@ -16,6 +16,7 @@ import requests import requests_mock from pytest_mock import MockerFixture +from sqlalchemy import create_engine, text from shillelagh.adapters.api.gsheets.adapter import GSheetsAPI from shillelagh.backends.apsw.db import connect @@ -191,7 +192,7 @@ def test_credentials() -> None: mock.call("BEGIN IMMEDIATE"), mock.call('SELECT 1 FROM "https://docs.google.com/spreadsheets/d/1"', None), mock.call( - "CREATE VIRTUAL TABLE \"https://docs.google.com/spreadsheets/d/1\" USING GSheetsAPI('+ihodHRwczovL2RvY3MuZ29vZ2xlLmNvbS9zcHJlYWRzaGVldHMvZC8x', 'Tg==', 'Tg==', '+9oGc2VjcmV02gNYWFgw', '+hB1c2VyQGV4YW1wbGUuY29t', 'Tg==', 'Rg==')", + "CREATE VIRTUAL TABLE \"https://docs.google.com/spreadsheets/d/1\" USING GSheetsAPI('+ihodHRwczovL2RvY3MuZ29vZ2xlLmNvbS9zcHJlYWRzaGVldHMvZC8x', 'Tg==', 'Tg==', '+9oGc2VjcmV02gNYWFgw', '+hB1c2VyQGV4YW1wbGUuY29t', 'Tg==', 'Rg==', 'Tg==')", ), mock.call('SELECT 1 FROM "https://docs.google.com/spreadsheets/d/1"', None), ], @@ -224,7 +225,7 @@ def test_credentials() -> None: mock.call("BEGIN IMMEDIATE"), mock.call('SELECT 1 FROM "https://docs.google.com/spreadsheets/d/1"', None), mock.call( - "CREATE VIRTUAL TABLE \"https://docs.google.com/spreadsheets/d/1\" USING GSheetsAPI('+ihodHRwczovL2RvY3MuZ29vZ2xlLmNvbS9zcHJlYWRzaGVldHMvZC8x', 'Tg==', 'Tg==', 'Tg==', 'Tg==', 'Tg==', 'VA==')", + "CREATE VIRTUAL TABLE \"https://docs.google.com/spreadsheets/d/1\" USING GSheetsAPI('+ihodHRwczovL2RvY3MuZ29vZ2xlLmNvbS9zcHJlYWRzaGVldHMvZC8x', 'Tg==', 'Tg==', 'Tg==', 'Tg==', 'Tg==', 'VA==', 'Tg==')", ), mock.call('SELECT 1 FROM "https://docs.google.com/spreadsheets/d/1"', None), ], @@ -2360,3 +2361,30 @@ def test_get_cost(mocker: MockerFixture) -> None: ) == 3022 ) + + +def test_session_verify( + mocker: MockerFixture, + simple_sheet_adapter: requests_mock.Adapter, +) -> None: + """ + Test setting ``verify`` in the session. + """ + session = requests.Session() + session.mount("https://", simple_sheet_adapter) + mocker.patch( + "shillelagh.adapters.api.gsheets.adapter.AuthorizedSession", + return_value=session, + ) + mocker.patch( + "shillelagh.adapters.api.gsheets.adapter.get_credentials", + return_value="SECRET", + ) + + engine = create_engine("gsheets://", session_verify=False) + connection = engine.connect() + + sql = '''SELECT * FROM "https://docs.google.com/spreadsheets/d/1/edit#gid=0"''' + connection.execute(text(sql)) + + assert session.verify is False diff --git a/tests/backends/apsw/dialects/gsheets_test.py b/tests/backends/apsw/dialects/gsheets_test.py index 635538fa..b87fee6c 100644 --- a/tests/backends/apsw/dialects/gsheets_test.py +++ b/tests/backends/apsw/dialects/gsheets_test.py @@ -39,6 +39,7 @@ def test_gsheets_dialect() -> None: "subject": None, "catalog": {}, "app_default_credentials": False, + "session_verify": None, }, }, "safe": True, @@ -63,6 +64,7 @@ def test_gsheets_dialect() -> None: "subject": "user@example.com", "catalog": {}, "app_default_credentials": False, + "session_verify": None, }, }, "safe": True, @@ -88,6 +90,7 @@ def test_gsheets_dialect() -> None: "subject": "user@example.com", "catalog": {"public_sheet": "https://example.com/"}, "app_default_credentials": False, + "session_verify": None, }, }, "safe": True, @@ -111,6 +114,7 @@ def test_gsheets_dialect() -> None: "subject": None, "catalog": {}, "app_default_credentials": True, + "session_verify": None, }, }, "safe": True, @@ -121,6 +125,28 @@ def test_gsheets_dialect() -> None: mock_dbapi_connection = mock.MagicMock() assert dialect.get_schema_names(mock_dbapi_connection) == [] + dialect = APSWGSheetsDialect(session_verify=False) + assert dialect.create_connect_args(make_url("gsheets://")) == ( + (), + { + "path": ":memory:", + "adapters": ["gsheetsapi"], + "adapter_kwargs": { + "gsheetsapi": { + "access_token": None, + "service_account_file": None, + "service_account_info": None, + "subject": None, + "catalog": {}, + "app_default_credentials": False, + "session_verify": False, + }, + }, + "safe": True, + "isolation_level": None, + }, + ) + def test_get_table_names(mocker: MockerFixture, requests_mock: Mocker) -> None: """