Skip to content

Commit

Permalink
SNOW-891470: Refactor reading default connection into a function (#1722)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jdu authored Sep 6, 2023
1 parent 86c0a11 commit 6a2dfd0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
13 changes: 13 additions & 0 deletions src/snowflake/connector/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from snowflake.connector.errors import (
ConfigManagerError,
ConfigSourceError,
Error,
MissingConfigOptionError,
)

Expand Down Expand Up @@ -457,6 +458,18 @@ def __getitem__(self, name: str) -> ConfigOption | ConfigManager:
)


def _get_default_connection_params() -> dict[str, Any]:
def_connection_name = CONFIG_MANAGER["default_connection_name"]
connections = CONFIG_MANAGER["connections"]
if def_connection_name not in connections:
raise Error(
f"Default connection with name '{def_connection_name}' "
"cannot be found, known ones are "
f"{list(connections.keys())}"
)
return {**connections[def_connection_name]}


def __getattr__(name):
if name == "CONFIG_PARSER":
warnings.warn(
Expand Down
12 changes: 2 additions & 10 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from .auth.idtoken import AuthByIdToken
from .bind_upload_agent import BindUploadError
from .compat import IS_LINUX, IS_WINDOWS, quote, urlencode
from .config_manager import CONFIG_MANAGER
from .config_manager import CONFIG_MANAGER, _get_default_connection_params
from .connection_diagnostic import ConnectionDiagnostic
from .constants import (
ENV_VAR_PARTNER,
Expand Down Expand Up @@ -367,15 +367,7 @@ def __init__(
kwargs = {**connections[connection_name], **kwargs}
elif is_kwargs_empty:
# connection_name is None and kwargs was empty when called
def_connection_name = CONFIG_MANAGER["default_connection_name"]
connections = CONFIG_MANAGER["connections"]
if def_connection_name not in connections:
raise Error(
f"Default connection with name '{def_connection_name}' "
"cannot be found, known ones are "
f"{list(connections.keys())}"
)
kwargs = {**connections[def_connection_name]}
kwargs = _get_default_connection_params()
self.__set_error_attributes()
self.connect(**kwargs)
self._telemetry = TelemetryClient(self._rest)
Expand Down

0 comments on commit 6a2dfd0

Please sign in to comment.