From 7526a18be53984536752fc88d7f3564785ea7e22 Mon Sep 17 00:00:00 2001 From: Matt Kornfield Date: Thu, 3 Aug 2023 13:54:28 -0700 Subject: [PATCH] Add azure transport params for get_record_handler_data --- src/gretel_trainer/relational/sdk_extras.py | 34 ++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/gretel_trainer/relational/sdk_extras.py b/src/gretel_trainer/relational/sdk_extras.py index e9550bad..40cbd0fc 100644 --- a/src/gretel_trainer/relational/sdk_extras.py +++ b/src/gretel_trainer/relational/sdk_extras.py @@ -1,4 +1,5 @@ import logging +import os import shutil from contextlib import suppress from pathlib import Path @@ -14,6 +15,15 @@ from gretel_trainer.relational.core import MultiTableException +try: + from azure.identity import DefaultAzureCredential +except ImportError: # pragma: no cover + DefaultAzureCredential = None +try: + from azure.storage.blob import BlobServiceClient +except ImportError: # pragma: no cover + BlobServiceClient = None + logger = logging.getLogger(__name__) MAX_PROJECT_ARTIFACTS = 50 @@ -79,8 +89,30 @@ def sqs_score_from_full_report(self, report: dict[str, Any]) -> Optional[int]: if field_dict["field"] == "synthetic_data_quality_score": return field_dict["value"] + def _get_azure_blob_srv_client(self) -> Optional[BlobServiceClient]: + if (storage_account := os.getenv("OAUTH_STORAGE_ACCOUNT_NAME")) is not None: + oauth_url = "https://{}.blob.core.windows.net".format(storage_account) + return BlobServiceClient( + account_url=oauth_url, credential=DefaultAzureCredential() + ) + + if (connect_str := os.getenv("AZURE_STORAGE_CONNECTION_STRING")) is not None: + return BlobServiceClient.from_connection_string(connect_str) + + def _get_transport_params(self, link: str) -> dict: + """Returns a set of transport params that are suitable for passing + into calls to ``smart_open.open``. + """ + client = None + if link.startswith("azure"): + client = self._get_azure_blob_srv_client() + return {"client": client} if client else {} + def get_record_handler_data(self, record_handler: RecordHandler) -> pd.DataFrame: - with smart_open.open(record_handler.get_artifact_link("data"), "rb") as data: + link = record_handler.get_artifact_link("data") + with smart_open.open( + link, "rb", transport_params=self._get_transport_params(link) + ) as data: return pd.read_csv(data) def start_job_if_possible(