Skip to content

Commit

Permalink
Add azure transport params for get_record_handler_data
Browse files Browse the repository at this point in the history
  • Loading branch information
mckornfield committed Aug 3, 2023
1 parent 51a064d commit 7526a18
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion src/gretel_trainer/relational/sdk_extras.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import shutil
from contextlib import suppress
from pathlib import Path
Expand All @@ -14,6 +15,15 @@

from gretel_trainer.relational.core import MultiTableException

try:
from azure.identity import DefaultAzureCredential
except ImportError: # pragma: no cover
DefaultAzureCredential = None
try:
from azure.storage.blob import BlobServiceClient
except ImportError: # pragma: no cover
BlobServiceClient = None

logger = logging.getLogger(__name__)

MAX_PROJECT_ARTIFACTS = 50
Expand Down Expand Up @@ -79,8 +89,30 @@ def sqs_score_from_full_report(self, report: dict[str, Any]) -> Optional[int]:
if field_dict["field"] == "synthetic_data_quality_score":
return field_dict["value"]

def _get_azure_blob_srv_client(self) -> Optional[BlobServiceClient]:
    """Build an Azure ``BlobServiceClient`` from environment configuration.

    Two environment variables are consulted, in this order of precedence:

    1. ``OAUTH_STORAGE_ACCOUNT_NAME`` -- connect to
       ``https://<account>.blob.core.windows.net`` and authenticate with
       ``DefaultAzureCredential``.
    2. ``AZURE_STORAGE_CONNECTION_STRING`` -- connect through
       ``BlobServiceClient.from_connection_string``.

    Returns:
        A configured client, or ``None`` when neither variable is set.

    Raises:
        ImportError: If one of the variables is set but the optional Azure
            package required to honor it is not installed.
    """
    storage_account = os.getenv("OAUTH_STORAGE_ACCOUNT_NAME")
    connect_str = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
    if storage_account is None and connect_str is None:
        return None

    # The module-level import guards rebind these names to ``None`` when the
    # optional Azure packages are missing. Calling ``None`` would surface as an
    # opaque ``TypeError``, so fail with an actionable message instead.
    if BlobServiceClient is None:
        raise ImportError(
            "Azure storage is configured through environment variables, but "
            "the 'azure-storage-blob' package is not installed."
        )

    if storage_account is not None:
        if DefaultAzureCredential is None:
            raise ImportError(
                "OAUTH_STORAGE_ACCOUNT_NAME is set, but the 'azure-identity' "
                "package is not installed."
            )
        oauth_url = f"https://{storage_account}.blob.core.windows.net"
        return BlobServiceClient(
            account_url=oauth_url, credential=DefaultAzureCredential()
        )

    return BlobServiceClient.from_connection_string(connect_str)

def _get_transport_params(self, link: str) -> dict:
"""Returns a set of transport params that are suitable for passing
into calls to ``smart_open.open``.
"""
client = None
if link.startswith("azure"):
client = self._get_azure_blob_srv_client()
return {"client": client} if client else {}

def get_record_handler_data(self, record_handler: RecordHandler) -> pd.DataFrame:
    """Read a record handler's ``data`` artifact into a DataFrame.

    The artifact link is resolved once and reused both to open the stream and
    to pick link-specific transport parameters (for example an Azure blob
    client), which are forwarded to ``smart_open.open``.

    Args:
        record_handler: Handler whose ``"data"`` artifact should be loaded.

    Returns:
        The artifact contents parsed as CSV.
    """
    link = record_handler.get_artifact_link("data")
    transport_params = self._get_transport_params(link)
    with smart_open.open(link, "rb", transport_params=transport_params) as data:
        return pd.read_csv(data)

def start_job_if_possible(
Expand Down

0 comments on commit 7526a18

Please sign in to comment.