
SNOW-1664063: sync main branch changes into async part (#2081)
sfc-gh-aling authored Oct 25, 2024
1 parent 2ce1be6 commit d5a8592
Showing 32 changed files with 507 additions and 54 deletions.
4 changes: 4 additions & 0 deletions DESCRIPTION.md
@@ -8,6 +8,10 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne

# Release Notes

- v3.12.3(TBD)
- Improved the error message for SSL-related issues to provide clearer guidance when an SSL error occurs.
- Improved error message for SQL execution cancellations caused by timeout.

- v3.12.2(September 11,2024)
- Improved error handling for asynchronous queries, providing more detailed and informative error messages when an async query fails.
- Improved inference of top-level domains for accounts specifying a region in China, now defaulting to snowflakecomputing.cn.
56 changes: 56 additions & 0 deletions src/snowflake/connector/_utils.py
@@ -0,0 +1,56 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

import string
from enum import Enum
from random import choice
from threading import Timer


class TempObjectType(Enum):
    TABLE = "TABLE"
    VIEW = "VIEW"
    STAGE = "STAGE"
    FUNCTION = "FUNCTION"
    FILE_FORMAT = "FILE_FORMAT"
    QUERY_TAG = "QUERY_TAG"
    COLUMN = "COLUMN"
    PROCEDURE = "PROCEDURE"
    TABLE_FUNCTION = "TABLE_FUNCTION"
    DYNAMIC_TABLE = "DYNAMIC_TABLE"
    AGGREGATE_FUNCTION = "AGGREGATE_FUNCTION"
    CTE = "CTE"


TEMP_OBJECT_NAME_PREFIX = "SNOWPARK_TEMP_"
ALPHANUMERIC = string.digits + string.ascii_lowercase
TEMPORARY_STRING = "TEMP"
SCOPED_TEMPORARY_STRING = "SCOPED TEMPORARY"
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING = (
    "PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS"
)


def generate_random_alphanumeric(length: int = 10) -> str:
    return "".join(choice(ALPHANUMERIC) for _ in range(length))


def random_name_for_temp_object(object_type: TempObjectType) -> str:
    return f"{TEMP_OBJECT_NAME_PREFIX}{object_type.value}_{generate_random_alphanumeric().upper()}"


def get_temp_type_for_object(use_scoped_temp_objects: bool) -> str:
    return SCOPED_TEMPORARY_STRING if use_scoped_temp_objects else TEMPORARY_STRING


class _TrackedQueryCancellationTimer(Timer):
    def __init__(self, interval, function, args=None, kwargs=None):
        super().__init__(interval, function, args, kwargs)
        self.executed = False

    def run(self):
        super().run()
        self.executed = True
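A minimal usage sketch of the new helpers (illustration only, not part of the diff); the import path mirrors the new module above, and the timer callback below is made up:

# Illustrative only -- not part of the commit.
from snowflake.connector._utils import (
    TempObjectType,
    _TrackedQueryCancellationTimer,
    get_temp_type_for_object,
    random_name_for_temp_object,
)

# e.g. "SNOWPARK_TEMP_STAGE_<10 random alphanumeric characters, upper-cased>"
stage_name = random_name_for_temp_object(TempObjectType.STAGE)

# "SCOPED TEMPORARY" when scoped temp objects are enabled, otherwise "TEMP"
ddl = f"create or replace {get_temp_type_for_object(True)} stage {stage_name}"

# The timer subclass records that its run() completed, which the cursor later
# uses to tell a client-side timeout cancellation apart from a plain server error.
timer = _TrackedQueryCancellationTimer(0.1, lambda: print("cancelling query"))
timer.start()
timer.join()
print(timer.executed)  # True once the timer has fired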
7 changes: 5 additions & 2 deletions src/snowflake/connector/aio/_azure_storage_client.py
@@ -14,6 +14,7 @@

import aiohttp

from ..azure_storage_client import AzureCredentialFilter
from ..azure_storage_client import (
    SnowflakeAzureRestClient as SnowflakeAzureRestClientSync,
)
@@ -25,14 +26,16 @@
if TYPE_CHECKING:  # pragma: no cover
    from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential

logger = getLogger(__name__)

from ..azure_storage_client import (
    ENCRYPTION_DATA,
    MATDESC,
    TOKEN_EXPIRATION_ERR_MESSAGE,
)

logger = getLogger(__name__)

getLogger("aiohttp").addFilter(AzureCredentialFilter())


class SnowflakeAzureRestClient(
    SnowflakeStorageClientAsync, SnowflakeAzureRestClientSync
7 changes: 6 additions & 1 deletion src/snowflake/connector/aio/_build_upload_agent.py
@@ -10,6 +10,7 @@
from typing import TYPE_CHECKING, cast

from snowflake.connector import Error
from snowflake.connector._utils import get_temp_type_for_object
from snowflake.connector.bind_upload_agent import BindUploadAgent as BindUploadAgentSync
from snowflake.connector.errors import BindUploadError

@@ -30,7 +31,11 @@ def __init__(
        self.cursor = cast("SnowflakeCursor", cursor)

    async def _create_stage(self) -> None:
        await self.cursor.execute(self._CREATE_STAGE_STMT)
        create_stage_sql = (
            f"create or replace {get_temp_type_for_object(self._use_scoped_temp_object)} stage {self._STAGE_NAME} "
            "file_format=(type=csv field_optionally_enclosed_by='\"')"
        )
        await self.cursor.execute(create_stage_sql)

    async def upload(self) -> None:
        try:
3 changes: 3 additions & 0 deletions src/snowflake/connector/aio/_connection.py
@@ -35,6 +35,7 @@
from ..connection import _get_private_bytes_from_file
from ..connection_diagnostic import ConnectionDiagnostic
from ..constants import (
    _CONNECTIVITY_ERR_MSG,
    ENV_VAR_PARTNER,
    PARAMETER_AUTOCOMMIT,
    PARAMETER_CLIENT_PREFETCH_THREADS,
@@ -443,6 +444,8 @@ async def _authenticate(self, auth_instance: AuthByPlugin):
                    )
                except OperationalError as auth_op:
                    if auth_op.errno == ER_FAILED_TO_CONNECT_TO_DB:
                        if _CONNECTIVITY_ERR_MSG in e.msg:
                            auth_op.msg += f"\n{_CONNECTIVITY_ERR_MSG}"
                        raise auth_op from e
                    logger.debug("Continuing authenticator specific timeout handling")
                    continue
13 changes: 12 additions & 1 deletion src/snowflake/connector/aio/_cursor.py
@@ -129,8 +129,10 @@ async def _timebomb_task(self, timeout, query):
logger.debug("started timebomb in %ss", timeout)
await asyncio.sleep(timeout)
await self.__cancel_query(query)
return True
except asyncio.CancelledError:
logger.debug("cancelled timebomb in timebomb task")
return False

async def __cancel_query(self, query) -> None:
if self._sequence_counter >= 0 and not self.is_closed():
@@ -284,7 +286,10 @@ def interrupt_handler(*_): # pragma: no cover
            )
            if self._timebomb is not None:
                self._timebomb.cancel()
                self._timebomb = None
                try:
                    await self._timebomb
                except asyncio.CancelledError:
                    pass
                logger.debug("cancelled timebomb in finally")

        if "data" in ret and "parameters" in ret["data"]:
@@ -674,6 +679,11 @@ async def execute(
            logger.debug(ret)
            err = ret["message"]
            code = ret.get("code", -1)
            if self._timebomb and self._timebomb.result():
                err = (
                    f"SQL execution was cancelled by the client due to a timeout. "
                    f"Error message received from the server: {err}"
                )
            if "data" in ret:
                err += ret["data"].get("errorMessage", "")
            errvalue = {
@@ -1067,6 +1077,7 @@ async def wait_until_ready() -> None:
        self._prefetch_hook = wait_until_ready

    async def query_result(self, qid: str) -> SnowflakeCursor:
        """Query the result of a previously executed query."""
        url = f"/queries/{qid}/result"
        ret = await self._connection.rest.request(url=url, method="get")
        self._sfqid = (
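A sketch of the resulting behaviour (illustration only): it assumes an already-established async connection `conn` and that `execute()` accepts the same `timeout` keyword as the sync cursor; `run_with_timeout` is a made-up helper name:

# Sketch only -- not part of the diff.
from snowflake.connector import errors

async def run_with_timeout(conn) -> None:
    cur = conn.cursor()
    try:
        # A query that outlives the 1-second timebomb is cancelled client-side.
        await cur.execute("SELECT SYSTEM$WAIT(60)", timeout=1)
    except errors.Error as exc:
        # The message is now prefixed with:
        # "SQL execution was cancelled by the client due to a timeout.
        #  Error message received from the server: ..."
        print(exc.msg)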
16 changes: 14 additions & 2 deletions src/snowflake/connector/aio/_network.py
@@ -28,6 +28,7 @@
    urlparse,
)
from ..constants import (
    _CONNECTIVITY_ERR_MSG,
    HTTP_HEADER_ACCEPT,
    HTTP_HEADER_CONTENT_TYPE,
    HTTP_HEADER_SERVICE_NAME,
@@ -798,8 +799,19 @@ async def _request_exec(
            finally:
                raw_ret.close()  # ensure response is closed
        except aiohttp.ClientSSLError as se:
            logger.debug("Hit non-retryable SSL error, %s", str(se))

            msg = f"Hit non-retryable SSL error, {str(se)}.\n{_CONNECTIVITY_ERR_MSG}"
            logger.debug(msg)
            # the following code is for backward compatibility with old versions of python connector which calls
            # self._handle_unknown_error to process SSLError
            Error.errorhandler_wrapper(
                self._connection,
                None,
                OperationalError,
                {
                    "msg": msg,
                    "errno": ER_FAILED_TO_REQUEST,
                },
            )
        # TODO: sync feature parity, aiohttp network error handling
        except (
            BadStatusLine,
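For reference (illustration only), how the new SSL error text is assembled; the SSL detail string below is made up, and `_CONNECTIVITY_ERR_MSG` is the constant added in constants.py later in this diff:

# Illustrative only -- mirrors the handler above with a fabricated error detail.
from snowflake.connector.constants import _CONNECTIVITY_ERR_MSG

ssl_details = "certificate verify failed"  # made-up aiohttp.ClientSSLError text
msg = f"Hit non-retryable SSL error, {ssl_details}.\n{_CONNECTIVITY_ERR_MSG}"
print(msg)
# Hit non-retryable SSL error, certificate verify failed.
# Verify that the hostnames and port numbers in SYSTEM$ALLOWLIST are added to your firewall's allowed list.
# To further troubleshoot your connection you may reference the following article: https://docs.snowflake.com/en/user-guide/client-connectivity-troubleshooting/overview.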
13 changes: 10 additions & 3 deletions src/snowflake/connector/aio/auth/_auth.py
@@ -13,7 +13,12 @@
from typing import TYPE_CHECKING, Any, Callable

from ...auth import Auth as AuthSync
from ...auth._auth import ID_TOKEN, MFA_TOKEN, delete_temporary_credential
from ...auth._auth import (
    AUTHENTICATION_REQUEST_KEY_WHITELIST,
    ID_TOKEN,
    MFA_TOKEN,
    delete_temporary_credential,
)
from ...compat import urlencode
from ...constants import (
    HTTP_HEADER_ACCEPT,
@@ -103,7 +108,6 @@

        body = copy.deepcopy(body_template)
        # updating request body
        logger.debug("assertion content: %s", auth_instance.assertion_content)
        await auth_instance.update_body(body)

        logger.debug(
@@ -141,7 +145,10 @@

        logger.debug(
            "body['data']: %s",
            {k: v for (k, v) in body["data"].items() if k != "PASSWORD"},
            {
                k: v if k in AUTHENTICATION_REQUEST_KEY_WHITELIST else "******"
                for (k, v) in body["data"].items()
            },
        )

        try:
18 changes: 16 additions & 2 deletions src/snowflake/connector/auth/_auth.py
@@ -112,6 +112,18 @@
ID_TOKEN = "ID_TOKEN"
MFA_TOKEN = "MFATOKEN"

AUTHENTICATION_REQUEST_KEY_WHITELIST = {
    "ACCOUNT_NAME",
    "AUTHENTICATOR",
    "CLIENT_APP_ID",
    "CLIENT_APP_VERSION",
    "CLIENT_ENVIRONMENT",
    "EXT_AUTHN_DUO_METHOD",
    "LOGIN_NAME",
    "SESSION_PARAMETERS",
    "SVN_REVISION",
}


class Auth:
"""Snowflake Authenticator."""
@@ -205,7 +217,6 @@ def authenticate(

        body = copy.deepcopy(body_template)
        # updating request body
        logger.debug("assertion content: %s", auth_instance.assertion_content)
        auth_instance.update_body(body)

        logger.debug(
@@ -243,7 +254,10 @@

        logger.debug(
            "body['data']: %s",
            {k: v for (k, v) in body["data"].items() if k != "PASSWORD"},
            {
                k: v if k in AUTHENTICATION_REQUEST_KEY_WHITELIST else "******"
                for (k, v) in body["data"].items()
            },
        )

        try:
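A standalone illustration of the new masking (not part of the diff); the sample payload below is made up:

# Keys outside the whitelist are redacted before the request body is logged.
from snowflake.connector.auth._auth import AUTHENTICATION_REQUEST_KEY_WHITELIST

data = {"ACCOUNT_NAME": "my_account", "LOGIN_NAME": "me", "PASSWORD": "hunter2", "TOKEN": "abc"}
masked = {
    k: v if k in AUTHENTICATION_REQUEST_KEY_WHITELIST else "******"
    for (k, v) in data.items()
}
print(masked)
# {'ACCOUNT_NAME': 'my_account', 'LOGIN_NAME': 'me', 'PASSWORD': '******', 'TOKEN': '******'}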
18 changes: 17 additions & 1 deletion src/snowflake/connector/azure_storage_client.py
@@ -8,7 +8,7 @@
import os
import xml.etree.ElementTree as ET
from datetime import datetime, timezone
from logging import getLogger
from logging import Filter, getLogger
from random import choice
from string import hexdigits
from typing import TYPE_CHECKING, Any, NamedTuple
@@ -39,6 +39,22 @@ class AzureLocation(NamedTuple):
MATDESC = "x-ms-meta-matdesc"


class AzureCredentialFilter(Filter):
    LEAKY_FMT = '%s://%s:%s "%s %s %s" %s %s'

    def filter(self, record):
        if record.msg == AzureCredentialFilter.LEAKY_FMT and len(record.args) == 8:
            record.args = (
                record.args[:4] + (record.args[4].split("?")[0],) + record.args[5:]
            )
        return True


getLogger("snowflake.connector.vendored.urllib3.connectionpool").addFilter(
    AzureCredentialFilter()
)


class SnowflakeAzureRestClient(SnowflakeStorageClient):
    def __init__(
        self,
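To see what the filter strips (illustration only), a fabricated connection-pool log record: the SAS-token query string is dropped from the URL argument before the line is emitted:

# Illustrative only -- exercises the filter on a hand-built log record.
import logging
from snowflake.connector.azure_storage_client import AzureCredentialFilter

record = logging.LogRecord(
    name="urllib3.connectionpool",
    level=logging.DEBUG,
    pathname="demo.py",
    lineno=0,
    msg='%s://%s:%s "%s %s %s" %s %s',
    args=("https", "account.blob.core.windows.net", 443, "PUT",
          "/container/chunk?sig=SECRET&se=2024", "HTTP/1.1", 201, 0),
    exc_info=None,
)
AzureCredentialFilter().filter(record)
print(record.getMessage())
# https://account.blob.core.windows.net:443 "PUT /container/chunk HTTP/1.1" 201 0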
25 changes: 19 additions & 6 deletions src/snowflake/connector/bind_upload_agent.py
@@ -10,6 +10,10 @@
from logging import getLogger
from typing import TYPE_CHECKING

from ._utils import (
    _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING,
    get_temp_type_for_object,
)
from .errors import BindUploadError, Error

if TYPE_CHECKING: # pragma: no cover
@@ -19,11 +23,6 @@


class BindUploadAgent:
    _STAGE_NAME = "SYSTEMBIND"
    _CREATE_STAGE_STMT = (
        f"create or replace temporary stage {_STAGE_NAME} "
        "file_format=(type=csv field_optionally_enclosed_by='\"')"
    )

    def __init__(
        self,
@@ -38,13 +37,27 @@ def __init__(
            rows: Rows of binding parameters in CSV format.
            stream_buffer_size: Size of each file, default to 10MB.
        """
        self._use_scoped_temp_object = (
            cursor.connection._session_parameters.get(
                _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING, False
            )
            if cursor.connection._session_parameters
            else False
        )
        self._STAGE_NAME = (
            "SNOWPARK_TEMP_STAGE_BIND" if self._use_scoped_temp_object else "SYSTEMBIND"
        )
        self.cursor = cursor
        self.rows = rows
        self._stream_buffer_size = stream_buffer_size
        self.stage_path = f"@{self._STAGE_NAME}/{uuid.uuid4().hex}"

    def _create_stage(self) -> None:
        self.cursor.execute(self._CREATE_STAGE_STMT)
        create_stage_sql = (
            f"create or replace {get_temp_type_for_object(self._use_scoped_temp_object)} stage {self._STAGE_NAME} "
            "file_format=(type=csv field_optionally_enclosed_by='\"')"
        )
        self.cursor.execute(create_stage_sql)

    def upload(self) -> None:
        try:
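For reference (illustration only), the DDL that _create_stage now emits for both settings of the PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS session parameter:

# Illustrative only -- prints the generated stage DDL for both cases.
from snowflake.connector._utils import get_temp_type_for_object

for use_scoped in (False, True):
    stage_name = "SNOWPARK_TEMP_STAGE_BIND" if use_scoped else "SYSTEMBIND"
    print(
        f"create or replace {get_temp_type_for_object(use_scoped)} stage {stage_name} "
        "file_format=(type=csv field_optionally_enclosed_by='\"')"
    )
# create or replace TEMP stage SYSTEMBIND file_format=(type=csv field_optionally_enclosed_by='"')
# create or replace SCOPED TEMPORARY stage SNOWPARK_TEMP_STAGE_BIND file_format=(type=csv field_optionally_enclosed_by='"')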
3 changes: 3 additions & 0 deletions src/snowflake/connector/connection.py
@@ -52,6 +52,7 @@
from .config_manager import CONFIG_MANAGER, _get_default_connection_params
from .connection_diagnostic import ConnectionDiagnostic
from .constants import (
    _CONNECTIVITY_ERR_MSG,
    _DOMAIN_NAME_MAP,
    ENV_VAR_PARTNER,
    PARAMETER_AUTOCOMMIT,
@@ -1455,6 +1456,8 @@ def _authenticate(self, auth_instance: AuthByPlugin):
                    )
                except OperationalError as auth_op:
                    if auth_op.errno == ER_FAILED_TO_CONNECT_TO_DB:
                        if _CONNECTIVITY_ERR_MSG in e.msg:
                            auth_op.msg += f"\n{_CONNECTIVITY_ERR_MSG}"
                        raise auth_op from e
                    logger.debug("Continuing authenticator specific timeout handling")
                    continue
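A sketch of the retry-path behaviour in _authenticate (illustration only); the two error objects below are fabricated stand-ins for the original failure and the retry failure:

# Sketch only -- not part of the diff; error objects are fabricated.
from snowflake.connector.constants import _CONNECTIVITY_ERR_MSG
from snowflake.connector.errors import OperationalError

# Suppose the first connection attempt already failed with the connectivity hint...
e = OperationalError(msg=f"Could not connect to Snowflake backend.\n{_CONNECTIVITY_ERR_MSG}")
# ...and the retry then fails with ER_FAILED_TO_CONNECT_TO_DB:
auth_op = OperationalError(msg="Failed to connect to DB.")
if _CONNECTIVITY_ERR_MSG in e.msg:
    auth_op.msg += f"\n{_CONNECTIVITY_ERR_MSG}"  # the hint survives the re-raise
print(auth_op.msg)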
6 changes: 6 additions & 0 deletions src/snowflake/connector/constants.py
@@ -428,3 +428,9 @@ class IterUnit(Enum):


_DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"}

_CONNECTIVITY_ERR_MSG = (
"Verify that the hostnames and port numbers in SYSTEM$ALLOWLIST are added to your firewall's allowed list."
"\nTo further troubleshoot your connection you may reference the following article: "
"https://docs.snowflake.com/en/user-guide/client-connectivity-troubleshooting/overview."
)