Skip to content

Commit

Permalink
chore: add application name (#1245)
Browse files Browse the repository at this point in the history
* chore: add application name

* fix parentheses to dataframe object

---------

Co-authored-by: Mo Sriha <[email protected]>
  • Loading branch information
wochinge and medsriha authored Dec 13, 2024
1 parent e898557 commit 8a435d9
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
db_schema: Optional[str] = None,
warehouse: Optional[str] = None,
login_timeout: Optional[int] = None,
application_name: Optional[str] = None,
) -> None:
"""
:param user: User's login.
Expand All @@ -82,6 +83,7 @@ def __init__(
:param db_schema: Name of the schema to use.
:param warehouse: Name of the warehouse to use.
:param login_timeout: Timeout in seconds for login. By default, 60 seconds.
:param application_name: Name of the application to use when connecting to Snowflake.
"""

self.user = user
Expand All @@ -91,6 +93,7 @@ def __init__(
self.db_schema = db_schema
self.warehouse = warehouse
self.login_timeout = login_timeout or 60
self.application_name = application_name

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -108,6 +111,7 @@ def to_dict(self) -> Dict[str, Any]:
db_schema=self.db_schema,
warehouse=self.warehouse,
login_timeout=self.login_timeout,
application_name=self.application_name,
)

@classmethod
Expand Down Expand Up @@ -285,6 +289,7 @@ def _fetch_data(
"schema": self.db_schema,
"warehouse": self.warehouse,
"login_timeout": self.login_timeout,
**({"application": self.application_name} if self.application_name else {}),
}
)
if conn is None:
Expand Down Expand Up @@ -325,7 +330,7 @@ def run(self, query: str) -> Dict[str, Any]:
if not query:
logger.error("Provide a valid SQL query.")
return {
"dataframe": pd.DataFrame,
"dataframe": pd.DataFrame(),
"table": "",
}
else:
Expand Down
62 changes: 61 additions & 1 deletion integrations/snowflake/tests/test_snowflake_table_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,64 @@ def test_run(self, mock_connect: MagicMock, snowflake_table_retriever: Snowflake

assert result["dataframe"].equals(expected["dataframe"])
assert result["table"] == expected["table"]
mock_connect.assert_called_once_with(
user="test_user",
account="test_account",
password="test-api-key",
database="test_database",
schema="test_schema",
warehouse="test_warehouse",
login_timeout=30,
)

@patch(
"haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect"
)
def test_run_with_application_name(
self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever
) -> None:
snowflake_table_retriever.application_name = "test_application"
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_col1 = MagicMock()
mock_col2 = MagicMock()
mock_cursor.fetchall.side_effect = [
[("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles
[
(
"DATETIME",
"SELECT",
"TABLE",
"locations",
"ROLE",
"ROLE_NAME",
"GRANT_OPTION",
"GRANTED_BY",
)
],
]
mock_col1.name = "City"
mock_col2.name = "State"
mock_cursor.description = [mock_col1, mock_col2]

mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")]
mock_conn.cursor.return_value = mock_cursor
mock_connect.return_value = mock_conn

query = "SELECT * FROM locations"

snowflake_table_retriever.run(query=query)

mock_connect.assert_called_once_with(
user="test_user",
account="test_account",
password="test-api-key",
database="test_database",
schema="test_schema",
warehouse="test_warehouse",
login_timeout=30,
application="test_application",
)

@pytest.fixture
def mock_chat_completion(self) -> Generator:
Expand Down Expand Up @@ -494,6 +552,7 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None:
"db_schema": "test_schema",
"warehouse": "test_warehouse",
"login_timeout": 30,
"application_name": None,
},
}

Expand All @@ -508,6 +567,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None:
db_schema="SMALL_TOWNS",
warehouse="COMPUTE_WH",
login_timeout=30,
application_name="test_application",
)

data = component.to_dict()
Expand All @@ -529,6 +589,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None:
"db_schema": "SMALL_TOWNS",
"warehouse": "COMPUTE_WH",
"login_timeout": 30,
"application_name": "test_application",
},
}

Expand Down Expand Up @@ -605,7 +666,6 @@ def test_empty_query(self, snowflake_table_retriever: SnowflakeTableRetriever) -
assert result.empty

def test_serialization_deserialization_pipeline(self) -> None:

pipeline = Pipeline()
pipeline.add_component("snow", SnowflakeTableRetriever(user="test_user", account="test_account"))
pipeline.add_component("prompt_builder", PromptBuilder(template="Display results {{ table }}"))
Expand Down

0 comments on commit 8a435d9

Please sign in to comment.