Skip to content

Commit

Permalink
SNOW-592647 consolidate definitions and resolve circular dependency i…
Browse files Browse the repository at this point in the history
…ssues (#1158)
  • Loading branch information
sfc-gh-mkeller authored Jun 8, 2022
1 parent 2d78cb0 commit 2ccd3c8
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 99 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne


- v2.7.9(Unreleased)

- Fixed a bug where errors raised during get_results_from_sfqid() were missing errno
- Fixed a bug where empty results containing GEOGRAPHY type raised IndexError


- v2.7.8(May 28,2022)
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/connector/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import urllib.request
from typing import Any

from snowflake.connector.constants import UTF8
from . import constants

IS_LINUX = platform.system() == "Linux"
IS_WINDOWS = platform.system() == "Windows"
Expand Down Expand Up @@ -111,7 +111,7 @@ def PKCS5_PAD(value: bytes, block_size: int) -> bytes:
[
value,
(block_size - len(value) % block_size)
* chr(block_size - len(value) % block_size).encode(UTF8),
* chr(block_size - len(value) % block_size).encode(constants.UTF8),
]
)

Expand Down
85 changes: 65 additions & 20 deletions src/snowflake/connector/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@

from collections import defaultdict
from enum import Enum, auto, unique
from typing import Any, DefaultDict, NamedTuple
from typing import Any, Callable, DefaultDict, NamedTuple

from .options import installed_pandas
from .options import pyarrow as pa

if installed_pandas:
DataType = pa.DataType
else:
DataType = None


DBAPI_TYPE_STRING = 0
DBAPI_TYPE_BINARY = 1
Expand All @@ -17,25 +26,61 @@
class FieldType(NamedTuple):
name: str
dbapi_type: list[int]


FIELD_TYPES: list[FieldType] = [
FieldType(name="FIXED", dbapi_type=[DBAPI_TYPE_NUMBER]),
FieldType(name="REAL", dbapi_type=[DBAPI_TYPE_NUMBER]),
FieldType(name="TEXT", dbapi_type=[DBAPI_TYPE_STRING]),
FieldType(name="DATE", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
FieldType(name="TIMESTAMP", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
FieldType(name="VARIANT", dbapi_type=[DBAPI_TYPE_BINARY]),
FieldType(name="TIMESTAMP_LTZ", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
FieldType(name="TIMESTAMP_TZ", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
FieldType(name="TIMESTAMP_NTZ", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
FieldType(name="OBJECT", dbapi_type=[DBAPI_TYPE_BINARY]),
FieldType(name="ARRAY", dbapi_type=[DBAPI_TYPE_BINARY]),
FieldType(name="BINARY", dbapi_type=[DBAPI_TYPE_BINARY]),
FieldType(name="TIME", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
FieldType(name="BOOLEAN", dbapi_type=[]),
FieldType(name="GEOGRAPHY", dbapi_type=[DBAPI_TYPE_STRING]),
]
pa_type: Callable[[], DataType]


# This type mapping holds column type definitions.
# Be careful not to change the ordering, as the index is what Snowflake
# gives us as the schema
FIELD_TYPES: tuple[FieldType] = (
FieldType(name="FIXED", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda: pa.int64()),
FieldType(
name="REAL", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda: pa.float64()
),
FieldType(name="TEXT", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string()),
FieldType(
name="DATE", dbapi_type=[DBAPI_TYPE_TIMESTAMP], pa_type=lambda: pa.date64()
),
FieldType(
name="TIMESTAMP",
dbapi_type=[DBAPI_TYPE_TIMESTAMP],
pa_type=lambda: pa.time64("ns"),
),
FieldType(
name="VARIANT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string()
),
FieldType(
name="TIMESTAMP_LTZ",
dbapi_type=[DBAPI_TYPE_TIMESTAMP],
pa_type=lambda: pa.timestamp("ns"),
),
FieldType(
name="TIMESTAMP_TZ",
dbapi_type=[DBAPI_TYPE_TIMESTAMP],
pa_type=lambda: pa.timestamp("ns"),
),
FieldType(
name="TIMESTAMP_NTZ",
dbapi_type=[DBAPI_TYPE_TIMESTAMP],
pa_type=lambda: pa.timestamp("ns"),
),
FieldType(
name="OBJECT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string()
),
FieldType(
name="ARRAY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string()
),
FieldType(
name="BINARY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.binary()
),
FieldType(
name="TIME", dbapi_type=[DBAPI_TYPE_TIMESTAMP], pa_type=lambda: pa.time64("ns")
),
FieldType(name="BOOLEAN", dbapi_type=[], pa_type=lambda: pa.bool_()),
FieldType(
name="GEOGRAPHY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string()
),
)

FIELD_NAME_TO_ID: DefaultDict[Any, int] = defaultdict(int)
FIELD_ID_TO_NAME: DefaultDict[int, str] = defaultdict(str)
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from snowflake.connector.result_batch import create_batches_from_response
from snowflake.connector.result_set import ResultSet

from . import compat
from .bind_upload_agent import BindUploadAgent, BindUploadError
from .compat import BASE_EXCEPTION_CLASS
from .constants import (
FIELD_NAME_TO_ID,
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
Expand Down Expand Up @@ -271,7 +271,7 @@ def __init__(
def __del__(self) -> None: # pragma: no cover
try:
self.close()
except BASE_EXCEPTION_CLASS as e:
except compat.BASE_EXCEPTION_CLASS as e:
if logger.getEffectiveLevel() <= logging.INFO:
logger.info(e)

Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/connector/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def telemetry_msg(self) -> str | None:

def generate_telemetry_exception_data(self) -> dict[str, str]:
"""Generate the data to send through telemetry."""

telemetry_data = {
TelemetryField.KEY_DRIVER_TYPE.value: CLIENT_NAME,
TelemetryField.KEY_DRIVER_VERSION.value: SNOWFLAKE_CONNECTOR_VERSION,
Expand All @@ -146,6 +147,7 @@ def send_exception_telemetry(
telemetry_data: dict[str, str],
) -> None:
"""Send telemetry data by in-band telemetry if it is enabled, otherwise send through out-of-band telemetry."""

if (
connection is not None
and connection.telemetry_enabled
Expand All @@ -164,6 +166,7 @@ def send_exception_telemetry(
logger.debug("Cursor failed to log to telemetry.", exc_info=True)
elif connection is None:
# Send with out-of-band telemetry

telemetry_oob = TelemetryService.get_instance()
telemetry_oob.log_general_exception(self.__class__.__name__, telemetry_data)

Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/connector/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import pkg_resources

from .errors import MissingDependencyError
from . import errors

logger = getLogger(__name__)

Expand All @@ -35,7 +35,7 @@ class MissingOptionalDependency:
_dep_name = "not set"

def __getattr__(self, item):
raise MissingDependencyError(self._dep_name)
raise errors.MissingDependencyError(self._dep_name)


class MissingPandas(MissingOptionalDependency):
Expand Down
45 changes: 12 additions & 33 deletions src/snowflake/connector/result_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .arrow_context import ArrowConverterContext
from .compat import OK, UNAUTHORIZED, urlparse
from .constants import IterUnit
from .constants import FIELD_TYPES, IterUnit
from .errorcode import ER_FAILED_TO_CONVERT_ROW_TO_PYTHON_TYPE, ER_NO_PYARROW
from .errors import Error, InterfaceError, NotSupportedError, ProgrammingError
from .network import (
Expand All @@ -25,6 +25,7 @@
raise_okta_unauthorized_error,
)
from .options import installed_pandas, pandas
from .options import pyarrow as pa
from .secret_detector import SecretDetector
from .time_util import DecorrelateJitterBackoff, TimerContextManager
from .vendored import requests
Expand All @@ -40,20 +41,13 @@
from .cursor import ResultMetadata, SnowflakeCursor
from .vendored.requests import Response

if installed_pandas:
from pyarrow import DataType, Table
from pyarrow import binary as pa_bin
from pyarrow import bool_ as pa_bool
from pyarrow import date64 as pa_date64
from pyarrow import field
from pyarrow import float64 as pa_flt64
from pyarrow import int64 as pa_int64
from pyarrow import schema
from pyarrow import string as pa_str
from pyarrow import time64 as pa_time64
from pyarrow import timestamp as pa_ts
else:
DataType, Table = None, None
if installed_pandas:
DataType = pa.DataType
Table = pa.Table
else:
DataType = None
Table = None


# empty pyarrow type array corresponding to FIELD_TYPES
FIELD_TYPE_TO_PA_TYPE: list[DataType] = []
Expand Down Expand Up @@ -655,26 +649,11 @@ def _create_empty_table(self) -> Table:
"""Returns emtpy Arrow table based on schema"""
if installed_pandas:
# initialize pyarrow type array corresponding to FIELD_TYPES
FIELD_TYPE_TO_PA_TYPE = [
pa_int64(),
pa_flt64(),
pa_str(),
pa_date64(),
pa_time64("ns"),
pa_str(),
pa_ts("ns"),
pa_ts("ns"),
pa_ts("ns"),
pa_str(),
pa_str(),
pa_bin(),
pa_time64("ns"),
pa_bool(),
]
FIELD_TYPE_TO_PA_TYPE = [e.pa_type() for e in FIELD_TYPES]
fields = [
field(s.name, FIELD_TYPE_TO_PA_TYPE[s.type_code]) for s in self.schema
pa.field(s.name, FIELD_TYPE_TO_PA_TYPE[s.type_code]) for s in self.schema
]
return schema(fields).empty_table()
return pa.schema(fields).empty_table()

def to_arrow(self, connection: SnowflakeConnection | None = None) -> Table:
"""Returns this batch as a pyarrow Table"""
Expand Down
42 changes: 17 additions & 25 deletions test/integ/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,41 +572,33 @@ def test_variant(conn, db_parameters):


@pytest.mark.skipolddriver
def test_geography(conn, db_parameters):
def test_geography(conn_cnx):
"""Variant including JSON object."""
name_geo = random_string(5, "test_geography_")
with conn() as cnx:
cnx.cursor().execute(
f"""\
create table {name_geo} (geo geography)
"""
)
cnx.cursor().execute(
f"""\
insert into {name_geo} values ('POINT(0 0)'), ('LINESTRING(1 1, 2 2)')
"""
)
expected_data = [
{"coordinates": [0, 0], "type": "Point"},
{"coordinates": [[1, 1], [2, 2]], "type": "LineString"},
]

try:
with conn() as cnx:
c = cnx.cursor()
c.execute("alter session set GEOGRAPHY_OUTPUT_FORMAT='geoJson'")
with conn_cnx(
session_parameters={
"GEOGRAPHY_OUTPUT_FORMAT": "geoJson",
},
) as cnx:
with cnx.cursor() as cur:
cur.execute(f"create temporary table {name_geo} (geo geography)")
cur.execute(
f"insert into {name_geo} values ('POINT(0 0)'), ('LINESTRING(1 1, 2 2)')"
)
expected_data = [
{"coordinates": [0, 0], "type": "Point"},
{"coordinates": [[1, 1], [2, 2]], "type": "LineString"},
]

with cnx.cursor() as cur:
# Test with GEOGRAPHY return type
result = c.execute(f"select * from {name_geo}")
result = cur.execute(f"select * from {name_geo}")
metadata = result.description
assert FIELD_ID_TO_NAME[metadata[0].type_code] == "GEOGRAPHY"
data = result.fetchall()
for raw_data in data:
row = json.loads(raw_data[0])
assert row in expected_data
finally:
with conn() as cnx:
cnx.cursor().execute(f"drop table {name_geo}")


def test_invalid_bind_data_type(conn_cnx):
Expand Down
23 changes: 10 additions & 13 deletions test/integ/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import snowflake.connector
import snowflake.connector.dbapi
from snowflake.connector import dbapi, errorcode, errors
from snowflake.connector.compat import BASE_EXCEPTION_CLASS

from ..randomize import random_string

Expand Down Expand Up @@ -273,20 +272,18 @@ def test_close(db_parameters):
# errorcode.ER_CURSOR_IS_CLOSED),'cursor.close() called twice in a row')

# calling cursor.execute after connection is closed should raise an error
try:
with pytest.raises(errors.Error) as e:
cur.execute(f"create or replace table {TABLE1} (name string)")
except BASE_EXCEPTION_CLASS as error:
assert (
error.errno == errorcode.ER_CURSOR_IS_CLOSED
), "cursor.execute() called twice in a row"
assert (
e.value.errno == errorcode.ER_CURSOR_IS_CLOSED
), "cursor.execute() called twice in a row"

# try to create a cursor on a closed connection
try:
con.cursor()
except BASE_EXCEPTION_CLASS as error:
assert (
error.errno == errorcode.ER_CONNECTION_IS_CLOSED
), "tried to create a cursor on a closed cursor"
# try to create a cursor on a closed connection
with pytest.raises(errors.Error) as e:
con.cursor()
assert (
e.value.errno == errorcode.ER_CONNECTION_IS_CLOSED
), "tried to create a cursor on a closed cursor"


def test_execute(conn_local):
Expand Down
2 changes: 1 addition & 1 deletion test/integ/test_put_get_user_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def _put_list_rm_files_in_stage(tmpdir, conn_cnx, elem):

from io import open

from snowflake.connector.compat import UTF8
from snowflake.connector.constants import UTF8

tmp_dir = str(tmpdir.mkdir("data"))
data_file = os.path.join(tmp_dir, data_file_name)
Expand Down

0 comments on commit 2ccd3c8

Please sign in to comment.