Skip to content

Commit

Permalink
SNOW-644849: Add telemetry about imported packages at runtime (#1236)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jdu authored Sep 1, 2022
1 parent 2c3f6d5 commit 4ccba6a
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 1 deletion.
39 changes: 38 additions & 1 deletion src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@
SnowflakeRestful,
)
from .sqlstate import SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_FEATURE_NOT_SUPPORTED
from .telemetry import TelemetryClient
from .telemetry import (
TelemetryClient,
TelemetryData,
TelemetryField,
generate_telemetry_data,
)
from .telemetry_oob import TelemetryService
from .time_util import HeartBeatTimer, get_time_millis
from .util_text import construct_hostname, parse_account, split_statements
Expand Down Expand Up @@ -192,6 +197,10 @@ def DefaultConverterClass():
None,
(type(None), str),
), # Path to connection diag whitelist json
"log_imported_packages_in_telemetry": (
True,
bool,
), # Whether to log imported packages in telemetry
}

APPLICATION_RE = re.compile(r"[\w\d_]+")
Expand Down Expand Up @@ -292,6 +301,9 @@ def __init__(self, **kwargs):
self.connect(**kwargs)
self._telemetry = TelemetryClient(self._rest)

# get the imported modules from sys.modules
self._log_telemetry_imported_packages()

def __del__(self): # pragma: no cover
try:
self.close(retry=False)
Expand Down Expand Up @@ -1541,3 +1553,28 @@ def _all_async_queries_finished(self) -> bool:
not self.is_still_running(self.get_query_status(q)) for q in queries
)
return all(finished_async_queries)

def _log_telemetry_imported_packages(self) -> None:
    """Log one telemetry record listing the top-level modules imported so far.

    Does nothing when ``self._log_imported_packages_in_telemetry`` is falsy.
    Otherwise a single record is passed to ``self._log_telemetry`` with:

    * type:   ``TelemetryField.IMPORTED_PACKAGES``
    * source: ``self.application``, or ``CLIENT_NAME`` when that is falsy
    * value:  ``str()`` of the set of top-level module names
    """
    if not self._log_imported_packages_in_telemetry:
        return

    # Iterate over a snapshot of the keys: sys.modules can change size while
    # we walk it (e.g. an import running in another thread), which would
    # raise "RuntimeError: dictionary changed size during iteration".
    # The set drops duplicates caused by submodules, and names starting with
    # an underscore (internal modules, "__main__") are filtered out.
    imported_modules = {
        k.split(".", maxsplit=1)[0]
        for k in tuple(sys.modules)
        if not k.startswith("_")
    }
    ts = get_time_millis()
    self._log_telemetry(
        TelemetryData(
            generate_telemetry_data(
                from_dict={
                    TelemetryField.KEY_TYPE.value: TelemetryField.IMPORTED_PACKAGES.value,
                    TelemetryField.KEY_SOURCE.value: self.application
                    if self.application
                    else CLIENT_NAME,
                    TelemetryField.KEY_VALUE.value: str(imported_modules),
                }
            ),
            ts,
        )
    )
2 changes: 2 additions & 0 deletions src/snowflake/connector/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class TelemetryField(Enum):
# fetch_arrow_* usage
ARROW_FETCH_ALL = "client_fetch_arrow_all"
ARROW_FETCH_BATCHES = "client_fetch_arrow_batches"
# imported packages along with client
IMPORTED_PACKAGES = "client_imported_packages"
# Keys for telemetry data sent through either in-band or out-of-band telemetry
KEY_TYPE = "type"
KEY_SOURCE = "source"
Expand Down
78 changes: 78 additions & 0 deletions test/integ/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from snowflake.connector.errors import Error, ForbiddenError
from snowflake.connector.network import APPLICATION_SNOWSQL, ReauthenticationRequest
from snowflake.connector.sqlstate import SQLSTATE_FEATURE_NOT_SUPPORTED
from snowflake.connector.telemetry import TelemetryField

try: # pragma: no cover
from parameters import CONNECTION_PARAMETERS_ADMIN
Expand Down Expand Up @@ -1107,3 +1108,80 @@ def test_ocsp_cache_working(conn_cnx):
with conn_cnx() as cnx:
assert cnx
assert OCSP_CACHE.telemetry["hit"] + OCSP_CACHE.telemetry["miss"] > original_count


@pytest.mark.skipolddriver
def test_imported_packages_telemetry(conn_cnx, capture_sf_telemetry, db_parameters):
    """Check the imported-packages telemetry record in three scenarios.

    1. Default connection: the record's source is ``CLIENT_NAME`` and its
       value mentions every expected package but not ``__main__``.
    2. Custom ``application``: the record's source is that application name.
    3. ``log_imported_packages_in_telemetry=False``: no record is captured.
    """
    # Imported only so that they are loaded when the telemetry is built;
    # none of these names is used directly.
    import html.parser  # noqa: F401
    import json  # noqa: F401
    import multiprocessing as mp  # noqa: F401
    from datetime import date  # noqa: F401
    from math import sqrt  # noqa: F401

    type_key = TelemetryField.KEY_TYPE.value
    source_key = TelemetryField.KEY_SOURCE.value
    packages_type = TelemetryField.IMPORTED_PACKAGES.value

    def check_packages(message: str, expected_packages: list[str]) -> bool:
        # Every expected name must occur in the message, and "__main__" must not.
        if not all([name in message for name in expected_packages]):
            return False
        return "__main__" not in message

    def is_packages_record(record, source) -> bool:
        # True when the record is an imported-packages record from `source`.
        payload = record.message
        return payload[type_key] == packages_type and source == payload[source_key]

    packages = [
        "pytest",
        "unittest",
        "json",
        "multiprocessing",
        "html",
        "datetime",
        "math",
    ]

    with conn_cnx() as conn:
        with capture_sf_telemetry.patch_connection(conn, False) as telemetry_test:
            conn._log_telemetry_imported_packages()
            assert len(telemetry_test.records) > 0
            matches = [
                is_packages_record(t, CLIENT_NAME)
                and check_packages(t.message["value"], packages)
                for t in telemetry_test.records
            ]
            assert any(matches)

    # test different application
    new_application_name = "PythonSnowpark"
    config = {
        key: db_parameters[key]
        for key in (
            "user",
            "password",
            "host",
            "port",
            "account",
            "schema",
            "database",
            "protocol",
        )
    }
    config["timezone"] = "UTC"
    config["application"] = new_application_name
    with snowflake.connector.connect(**config) as conn:
        with capture_sf_telemetry.patch_connection(conn, False) as telemetry_test:
            conn._log_telemetry_imported_packages()
            assert len(telemetry_test.records) > 0
            matches = [
                is_packages_record(t, new_application_name)
                for t in telemetry_test.records
            ]
            assert any(matches)

    # test opt out
    config.update(log_imported_packages_in_telemetry=False)
    with snowflake.connector.connect(**config) as conn:
        with capture_sf_telemetry.patch_connection(conn, False) as telemetry_test:
            conn._log_telemetry_imported_packages()
            assert len(telemetry_test.records) == 0

0 comments on commit 4ccba6a

Please sign in to comment.