Skip to content

Commit

Permalink
SNOW-1757241: migrate all integ test (#2076)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aling authored Oct 24, 2024
1 parent 6c5794c commit 0c53f44
Show file tree
Hide file tree
Showing 35 changed files with 4,529 additions and 13 deletions.
28 changes: 25 additions & 3 deletions src/snowflake/connector/aio/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import itertools
import json
import logging
import re
import uuid
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -163,7 +164,7 @@ def __init__(
self._ocsp_mode = (
self._connection._ocsp_mode() if self._connection else OCSPMode.FAIL_OPEN
)
if self._connection.proxy_host:
if self._connection and self._connection.proxy_host:
self._get_proxy_headers = lambda url: {"Host": parse_url(url).hostname}
else:
self._get_proxy_headers = lambda _: None
Expand Down Expand Up @@ -416,6 +417,7 @@ async def _get_request(
headers: dict[str, str],
token: str = None,
timeout: int | None = None,
is_fetch_query_status: bool = False,
) -> dict[str, Any]:
if "Content-Encoding" in headers:
del headers["Content-Encoding"]
Expand All @@ -429,6 +431,7 @@ async def _get_request(
headers,
timeout=timeout,
token=token,
is_fetch_query_status=is_fetch_query_status,
)
if ret.get("code") == SESSION_EXPIRED_GS_CODE:
try:
Expand All @@ -443,7 +446,12 @@ async def _get_request(
)
)
if ret.get("success"):
return await self._get_request(url, headers, token=self.token)
return await self._get_request(
url,
headers,
token=self.token,
is_fetch_query_status=is_fetch_query_status,
)

return ret

Expand Down Expand Up @@ -517,7 +525,13 @@ async def _post_request(
result_url = ret["data"]["getResultUrl"]
logger.debug("ping pong starting...")
ret = await self._get_request(
result_url, headers, token=self.token, timeout=timeout
result_url,
headers,
token=self.token,
timeout=timeout,
is_fetch_query_status=bool(
re.match(r"^/queries/.+/result$", result_url)
),
)
logger.debug("ret[code] = %s", ret.get("code", "N/A"))
logger.debug("ping pong done")
Expand Down Expand Up @@ -603,6 +617,7 @@ async def _request_exec_wrapper(

full_url = retry_ctx.add_retry_params(full_url)
full_url = SnowflakeRestful.add_request_guid(full_url)
is_fetch_query_status = kwargs.pop("is_fetch_query_status", False)
try:
return_object = await self._request_exec(
session=session,
Expand All @@ -615,6 +630,13 @@ async def _request_exec_wrapper(
)
if return_object is not None:
return return_object
if is_fetch_query_status:
err_msg = (
"fetch query status failed and http request returned None, this"
" is usually caused by transient network failures, retrying..."
)
logger.info(err_msg)
raise RetryRequest(err_msg)
self._handle_unknown_error(method, full_url, headers, data, conn)
return {}
except RetryRequest as e:
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/connector/aio/_s3_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def generate_authenticated_url_and_args_v4() -> tuple[str, dict[str, bytes]]:
if payload:
rest_args["data"] = payload

# ignore_content_encoding is removed because it
# does not apply to asyncio
if ignore_content_encoding:
rest_args["auto_decompress"] = False

return url, rest_args

Expand Down
3 changes: 3 additions & 0 deletions test/integ/aio/lambda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
25 changes: 25 additions & 0 deletions test/integ/aio/lambda/test_basic_query_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env python

#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#


async def test_connection(conn_cnx):
"""Test basic connection."""
async with conn_cnx() as cnx:
cur = cnx.cursor()
result = await (await cur.execute("select 1;")).fetchall()
assert result == [(1,)]


async def test_large_resultset(conn_cnx):
"""Test large resultset."""
async with conn_cnx() as cnx:
cur = cnx.cursor()
result = await (
await cur.execute(
"select seq8(), randstr(1000, random()) from table(generator(rowcount=>10000));"
)
).fetchall()
assert len(result) == 10000
3 changes: 3 additions & 0 deletions test/integ/aio/sso/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
187 changes: 187 additions & 0 deletions test/integ/aio/sso/test_connection_manual_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
#!/usr/bin/env python
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

# This test requires the SSO and Snowflake admin connection parameters.
#
# CONNECTION_PARAMETERS_SSO = {
# 'account': 'testaccount',
# 'user': '[email protected]',
# 'protocol': 'http',
# 'host': 'testaccount.reg.snowflakecomputing.com',
# 'port': '8082',
# 'authenticator': 'externalbrowser',
# 'timezone': 'UTC',
# }
#
# CONNECTION_PARAMETERS_ADMIN = { ... Snowflake admin ... }
import os
import sys

import pytest

import snowflake.connector.aio
from snowflake.connector.auth._auth import delete_temporary_credential

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
from parameters import CONNECTION_PARAMETERS_SSO
except ImportError:
CONNECTION_PARAMETERS_SSO = {}

try:
from parameters import CONNECTION_PARAMETERS_ADMIN
except ImportError:
CONNECTION_PARAMETERS_ADMIN = {}

ID_TOKEN = "ID_TOKEN"


@pytest.fixture
async def token_validity_test_values(request):
async with snowflake.connector.aio.SnowflakeConnection(
**CONNECTION_PARAMETERS_ADMIN
) as cnx:
await cnx.cursor().execute(
"""
ALTER SYSTEM SET
MASTER_TOKEN_VALIDITY=60,
SESSION_TOKEN_VALIDITY=5,
ID_TOKEN_VALIDITY=60
"""
)
# ALLOW_UNPROTECTED_ID_TOKEN is going to be deprecated in the future
# cnx.cursor().execute("alter account testaccount set ALLOW_UNPROTECTED_ID_TOKEN=true;")
await cnx.cursor().execute("alter account testaccount set ALLOW_ID_TOKEN=true;")
await cnx.cursor().execute(
"alter account testaccount set ID_TOKEN_FEATURE_ENABLED=true;"
)

async def fin():
async with snowflake.connector.connect(**CONNECTION_PARAMETERS_ADMIN) as cnx:
await cnx.cursor().execute(
"""
ALTER SYSTEM SET
MASTER_TOKEN_VALIDITY=default,
SESSION_TOKEN_VALIDITY=default,
ID_TOKEN_VALIDITY=default
"""
)

request.addfinalizer(fin)
return None


@pytest.mark.skipif(
not (
CONNECTION_PARAMETERS_SSO
and CONNECTION_PARAMETERS_ADMIN
and delete_temporary_credential
),
reason="SSO and ADMIN connection parameters must be provided.",
)
async def test_connect_externalbrowser(token_validity_test_values):
"""SSO Id Token Cache tests. This test should only be ran if keyring optional dependency is installed.
In order to run this test, remove the above pytest.mark.skip annotation and run it. It will popup a windows once
but the rest connections should not create popups.
"""
delete_temporary_credential(
host=CONNECTION_PARAMETERS_SSO["host"],
user=CONNECTION_PARAMETERS_SSO["user"],
cred_type=ID_TOKEN,
) # delete existing temporary credential
CONNECTION_PARAMETERS_SSO["client_store_temporary_credential"] = True

# change database and schema to non-default one
print(
"[INFO] 1st connection gets id token and stores in the local cache (keychain/credential manager/cache file). "
"This popup a browser to SSO login"
)
cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO)
await cnx.connect()
assert cnx.database == "TESTDB"
assert cnx.schema == "PUBLIC"
assert cnx.role == "SYSADMIN"
assert cnx.warehouse == "REGRESS"
ret = await (
await cnx.cursor().execute(
"select current_database(), current_schema(), "
"current_role(), current_warehouse()"
)
).fetchall()
assert ret[0][0] == "TESTDB"
assert ret[0][1] == "PUBLIC"
assert ret[0][2] == "SYSADMIN"
assert ret[0][3] == "REGRESS"
await cnx.close()

print(
"[INFO] 2nd connection reads the local cache and uses the id token. "
"This should not popups a browser."
)
CONNECTION_PARAMETERS_SSO["database"] = "testdb"
CONNECTION_PARAMETERS_SSO["schema"] = "testschema"
cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO)
await cnx.connect()
print(
"[INFO] Running a 10 seconds query. If the session expires in 10 "
"seconds, the query should renew the token in the middle, "
"and the current objects should be refreshed."
)
await cnx.cursor().execute("select seq8() from table(generator(timelimit=>10))")
assert cnx.database == "TESTDB"
assert cnx.schema == "TESTSCHEMA"
assert cnx.role == "SYSADMIN"
assert cnx.warehouse == "REGRESS"

print("[INFO] Running a 1 second query. ")
await cnx.cursor().execute("select seq8() from table(generator(timelimit=>1))")
assert cnx.database == "TESTDB"
assert cnx.schema == "TESTSCHEMA"
assert cnx.role == "SYSADMIN"
assert cnx.warehouse == "REGRESS"

print(
"[INFO] Running a 90 seconds query. This pops up a browser in the "
"middle of the query."
)
await cnx.cursor().execute("select seq8() from table(generator(timelimit=>90))")
assert cnx.database == "TESTDB"
assert cnx.schema == "TESTSCHEMA"
assert cnx.role == "SYSADMIN"
assert cnx.warehouse == "REGRESS"

await cnx.close()

# change database and schema again to ensure they are overridden
CONNECTION_PARAMETERS_SSO["database"] = "testdb"
CONNECTION_PARAMETERS_SSO["schema"] = "testschema"
cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO)
await cnx.connect()
assert cnx.database == "TESTDB"
assert cnx.schema == "TESTSCHEMA"
assert cnx.role == "SYSADMIN"
assert cnx.warehouse == "REGRESS"
await cnx.close()

async with snowflake.connector.aio.SnowflakeConnection(
**CONNECTION_PARAMETERS_ADMIN
) as cnx_admin:
# cnx_admin.cursor().execute("alter account testaccount set ALLOW_UNPROTECTED_ID_TOKEN=false;")
await cnx_admin.cursor().execute(
"alter account testaccount set ALLOW_ID_TOKEN=false;"
)
await cnx_admin.cursor().execute(
"alter account testaccount set ID_TOKEN_FEATURE_ENABLED=false;"
)
print(
"[INFO] Login again with ALLOW_UNPROTECTED_ID_TOKEN unset. Please make sure this pops up the browser"
)
cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO)
await cnx.connect()
await cnx.close()
Loading

0 comments on commit 0c53f44

Please sign in to comment.