Skip to content

Commit

Permalink
chore: add application name
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge committed Dec 12, 2024
1 parent e898557 commit 72634ca
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
db_schema: Optional[str] = None,
warehouse: Optional[str] = None,
login_timeout: Optional[int] = None,
application_name: Optional[str] = None,
) -> None:
"""
:param user: User's login.
Expand All @@ -82,6 +83,7 @@ def __init__(
:param db_schema: Name of the schema to use.
:param warehouse: Name of the warehouse to use.
:param login_timeout: Timeout in seconds for login. By default, 60 seconds.
:param application_name: Name of the application to use when connecting to Snowflake.
"""

self.user = user
Expand All @@ -91,6 +93,7 @@ def __init__(
self.db_schema = db_schema
self.warehouse = warehouse
self.login_timeout = login_timeout or 60
self.application_name = application_name

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -108,6 +111,7 @@ def to_dict(self) -> Dict[str, Any]:
db_schema=self.db_schema,
warehouse=self.warehouse,
login_timeout=self.login_timeout,
application_name=self.application_name,
)

@classmethod
Expand Down Expand Up @@ -285,6 +289,7 @@ def _fetch_data(
"schema": self.db_schema,
"warehouse": self.warehouse,
"login_timeout": self.login_timeout,
**({"application": self.application_name} if self.application_name else {}),
}
)
if conn is None:
Expand Down
82 changes: 67 additions & 15 deletions integrations/snowflake/tests/test_snowflake_table_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,23 +193,15 @@ def test_extract_multiple_table_names(self, snowflake_table_retriever: Snowflake
def test_extract_multiple_db_schema_from_table_names(
self, snowflake_table_retriever: SnowflakeTableRetriever
) -> None:
assert (
snowflake_table_retriever._extract_table_names(
query="""SELECT a.name, b.value FROM DB.SCHEMA.TABLE_A AS a
assert snowflake_table_retriever._extract_table_names(
query="""SELECT a.name, b.value FROM DB.SCHEMA.TABLE_A AS a
FULL OUTER JOIN DATABASE.SCHEMA.TABLE_b AS b ON a.id = b.id"""
)
== ["DB.SCHEMA.TABLE_A", "DATABASE.SCHEMA.TABLE_A"]
or ["DATABASE.SCHEMA.TABLE_A", "DB.SCHEMA.TABLE_B"]
)
) == ["DB.SCHEMA.TABLE_A", "DATABASE.SCHEMA.TABLE_A"] or ["DATABASE.SCHEMA.TABLE_A", "DB.SCHEMA.TABLE_B"]
# No database
assert (
snowflake_table_retriever._extract_table_names(
query="""SELECT a.name, b.value FROM SCHEMA.TABLE_A AS a
assert snowflake_table_retriever._extract_table_names(
query="""SELECT a.name, b.value FROM SCHEMA.TABLE_A AS a
FULL OUTER JOIN SCHEMA.TABLE_b AS b ON a.id = b.id"""
)
== ["SCHEMA.TABLE_A", "SCHEMA.TABLE_A"]
or ["SCHEMA.TABLE_A", "SCHEMA.TABLE_B"]
)
) == ["SCHEMA.TABLE_A", "SCHEMA.TABLE_A"] or ["SCHEMA.TABLE_A", "SCHEMA.TABLE_B"]

@patch(
"haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect"
Expand Down Expand Up @@ -352,6 +344,64 @@ def test_run(self, mock_connect: MagicMock, snowflake_table_retriever: Snowflake

assert result["dataframe"].equals(expected["dataframe"])
assert result["table"] == expected["table"]
mock_connect.assert_called_once_with(
user="test_user",
account="test_account",
password="test-api-key",
database="test_database",
schema="test_schema",
warehouse="test_warehouse",
login_timeout=30,
)

@patch(
"haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect"
)
def test_run_with_application_name(
self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever
) -> None:
snowflake_table_retriever.application_name = "test_application"
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_col1 = MagicMock()
mock_col2 = MagicMock()
mock_cursor.fetchall.side_effect = [
[("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles
[
(
"DATETIME",
"SELECT",
"TABLE",
"locations",
"ROLE",
"ROLE_NAME",
"GRANT_OPTION",
"GRANTED_BY",
)
],
]
mock_col1.name = "City"
mock_col2.name = "State"
mock_cursor.description = [mock_col1, mock_col2]

mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")]
mock_conn.cursor.return_value = mock_cursor
mock_connect.return_value = mock_conn

query = "SELECT * FROM locations"

snowflake_table_retriever.run(query=query)

mock_connect.assert_called_once_with(
user="test_user",
account="test_account",
password="test-api-key",
database="test_database",
schema="test_schema",
warehouse="test_warehouse",
login_timeout=30,
application="test_application",
)

@pytest.fixture
def mock_chat_completion(self) -> Generator:
Expand Down Expand Up @@ -494,6 +544,7 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None:
"db_schema": "test_schema",
"warehouse": "test_warehouse",
"login_timeout": 30,
"application_name": None,
},
}

Expand All @@ -508,6 +559,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None:
db_schema="SMALL_TOWNS",
warehouse="COMPUTE_WH",
login_timeout=30,
application_name="test_application",
)

data = component.to_dict()
Expand All @@ -529,6 +581,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None:
"db_schema": "SMALL_TOWNS",
"warehouse": "COMPUTE_WH",
"login_timeout": 30,
"application_name": "test_application",
},
}

Expand Down Expand Up @@ -605,7 +658,6 @@ def test_empty_query(self, snowflake_table_retriever: SnowflakeTableRetriever) -
assert result.empty

def test_serialization_deserialization_pipeline(self) -> None:

pipeline = Pipeline()
pipeline.add_component("snow", SnowflakeTableRetriever(user="test_user", account="test_account"))
pipeline.add_component("prompt_builder", PromptBuilder(template="Display results {{ table }}"))
Expand Down

0 comments on commit 72634ca

Please sign in to comment.