Skip to content

Commit

Permalink
SNOW-1058245-sqlalchemy-20-support: cast connectio query params to ty…
Browse files Browse the repository at this point in the history
…pes defined in connector
  • Loading branch information
sfc-gh-mraba committed Jul 2, 2024
1 parent b0bcb75 commit fdb1925
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 54 deletions.
48 changes: 23 additions & 25 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import operator
from collections import defaultdict
from functools import reduce
from typing import Any
from urllib.parse import unquote_plus

import sqlalchemy.types as sqltypes
Expand Down Expand Up @@ -63,7 +64,7 @@
_CUSTOM_Float,
_CUSTOM_Time,
)
from .util import _update_connection_application_name, parse_url_boolean
from .util import parse_url_boolean, parse_url_integer

colspecs = {
Date: _CUSTOM_Date,
Expand Down Expand Up @@ -203,6 +204,26 @@ def import_dbapi(cls):

return connector

@staticmethod
def parse_query_param_type(name: str, value: Any) -> Any:
"""Cast param value if possible to type defined in connector-python."""
if not (maybe_type_configuration := DEFAULT_CONFIGURATION.get(name)):
return value

_, expected_type = maybe_type_configuration
if not isinstance(expected_type, tuple):
expected_type = (expected_type,)

if isinstance(value, expected_type):
return value

elif bool in expected_type:
return parse_url_boolean(value)
elif int in expected_type:
return parse_url_integer(value)
else:
return value

def create_connect_args(self, url: URL):
opts = url.translate_connect_args(username="user")
if "database" in opts:
Expand Down Expand Up @@ -239,30 +260,7 @@ def create_connect_args(self, url: URL):

# URL sets the query parameter values as strings, we need to cast to expected types when necessary
for name, value in query.items():
maybe_type_configuration = DEFAULT_CONFIGURATION.get(name)
if (
not maybe_type_configuration
): # if the parameter is not found in the type mapping, pass it through as a string
opts[name] = value
continue

(_, expected_type) = maybe_type_configuration
if not isinstance(expected_type, tuple):
expected_type = (expected_type,)

if isinstance(
value, expected_type
): # if the expected type is str, pass it through as a string
opts[name] = value

elif (
bool in expected_type
): # if the expected type is bool, parse it and pass as a boolean
opts[name] = parse_url_boolean(value)
else:
# TODO: other types like int are stil passed through as string
# https://github.com/snowflakedb/snowflake-sqlalchemy/issues/447
opts[name] = value
opts[name] = self.parse_query_param_type(name, value)

return ([], opts)

Expand Down
7 changes: 7 additions & 0 deletions src/snowflake/sqlalchemy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ def parse_url_boolean(value: str) -> bool:
raise ValueError(f"Invalid boolean value detected: '{value}'")


def parse_url_integer(value: str) -> int:
try:
return int(value)
except ValueError as e:
raise ValueError(f"Invalid int value detected: '{value}") from e


# handle Snowflake BCR bcr-1057
# the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState
# which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that
Expand Down
45 changes: 16 additions & 29 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
inspect,
text,
)
from sqlalchemy.exc import DBAPIError, NoSuchTableError
from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError
from sqlalchemy.sql import and_, not_, or_, select

import snowflake.connector.errors
Expand Down Expand Up @@ -1059,28 +1059,15 @@ def harass_inspector():
assert outcome


@pytest.mark.timeout(15)
def test_region():
engine = create_engine(
URL(
user="testuser",
password="testpassword",
account="testaccount",
region="eu-central-1",
login_timeout=5,
)
)
try:
engine.connect()
pytest.fail("should not run")
except Exception as ex:
assert ex.orig.errno == 250001
assert "Failed to connect to DB" in ex.orig.msg
assert "testaccount.eu-central-1.snowflakecomputing.com" in ex.orig.msg


@pytest.mark.timeout(15)
def test_azure():
@pytest.mark.timeout(10)
@pytest.mark.parametrize(
"region",
(
pytest.param("eu-central-1", id="region"),
pytest.param("east-us-2.azure", id="azure"),
),
)
def test_connection_timeout_error(region):
engine = create_engine(
URL(
user="testuser",
Expand All @@ -1090,13 +1077,13 @@ def test_azure():
login_timeout=5,
)
)
try:

with pytest.raises(OperationalError) as excinfo:
engine.connect()
pytest.fail("should not run")
except Exception as ex:
assert ex.orig.errno == 250001
assert "Failed to connect to DB" in ex.orig.msg
assert "testaccount.east-us-2.azure.snowflakecomputing.com" in ex.orig.msg

assert excinfo.value.orig.errno == 250001
assert "Could not connect to Snowflake backend" in excinfo.value.orig.msg
assert region not in excinfo.value.orig.msg


def test_load_dialect():
Expand Down

0 comments on commit fdb1925

Please sign in to comment.