Skip to content

Commit

Permalink
Refactor account_url use in WasbHook (#32980)
Browse files Browse the repository at this point in the history
* Refactor account_url use in WasbHook

This PR moves the account_url setting to one place.
Tested this by making connection to azure using the different methods, however, I was not able to connect using
the tenant_id in the extra field. This looks like a bug because ClientSecretCredential is not among the credentials
to use in BlobServiceClient. The credentials to use include AzureNamedKeyCredential,AzureSasCredential,AsyncTokenCredential.
So this will need special debugging.

* fixup! Refactor account_url use in WasbHook
  • Loading branch information
ephraimbuddy authored Aug 2, 2023
1 parent 5f5293f commit df74553
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
32 changes: 17 additions & 15 deletions airflow/providers/microsoft/azure/hooks/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"relabeling": {
"login": "Blob Storage Login (optional)",
"password": "Blob Storage Key (optional)",
"host": "Account Name (Active Directory Auth)",
"host": "Account URL (Active Directory Auth)",
},
"placeholders": {
"login": "account name",
Expand All @@ -154,7 +154,7 @@ def __init__(
super().__init__()
self.conn_id = wasb_conn_id
self.public_read = public_read
self.blob_service_client = self.get_conn()
self.blob_service_client: BlobServiceClient = self.get_conn()

logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
try:
Expand Down Expand Up @@ -184,15 +184,19 @@ def get_conn(self) -> BlobServiceClient:
# connection_string auth takes priority
return BlobServiceClient.from_connection_string(connection_string, **extra)

account_url = (
conn.host
if conn.host and conn.host.startswith("https://")
else f"https://{conn.login}.blob.core.windows.net/"
)

tenant = self._get_field(extra, "tenant_id")
if tenant:
# use Active Directory auth
app_id = conn.login
app_secret = conn.password
token_credential = ClientSecretCredential(tenant, app_id, app_secret, **client_secret_auth_config)
return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra)

account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
return BlobServiceClient(account_url=account_url, credential=token_credential, **extra)

if self.public_read:
# Here we use anonymous public read
Expand All @@ -210,19 +214,13 @@ def get_conn(self) -> BlobServiceClient:
if sas_token.startswith("https"):
return BlobServiceClient(account_url=sas_token, **extra)
else:
if not account_url.startswith("https://"):
# TODO: require url in the host field in the next major version?
account_url = f"https://{conn.login}.blob.core.windows.net"
return BlobServiceClient(account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra)

# Fall back to old auth (password) or use managed identity if not provided.
credential = conn.password
if not credential:
credential = DefaultAzureCredential()
self.log.info("Using DefaultAzureCredential as credential")
if not account_url.startswith("https://"):
# TODO: require url in the host field in the next major version?
account_url = f"https://{conn.login}.blob.core.windows.net/"
return BlobServiceClient(
account_url=account_url,
credential=credential,
Expand Down Expand Up @@ -589,6 +587,12 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
)
return self.blob_service_client

account_url = (
conn.host
if conn.host and conn.host.startswith("https://")
else f"https://{conn.login}.blob.core.windows.net/"
)

tenant = self._get_field(extra, "tenant_id")
if tenant:
# use Active Directory auth
Expand All @@ -598,12 +602,10 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
tenant, app_id, app_secret, **client_secret_auth_config
)
self.blob_service_client = AsyncBlobServiceClient(
account_url=conn.host, credential=token_credential, **extra # type:ignore[arg-type]
account_url=account_url, credential=token_credential, **extra # type:ignore[arg-type]
)
return self.blob_service_client

account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"

if self.public_read:
# Here we use anonymous public read
# more info
Expand All @@ -625,7 +627,7 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
self.blob_service_client = AsyncBlobServiceClient(account_url=sas_token, **extra)
else:
self.blob_service_client = AsyncBlobServiceClient(
account_url=f"{account_url}/{sas_token}", **extra
account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra
)
return self.blob_service_client

Expand Down
2 changes: 1 addition & 1 deletion tests/providers/microsoft/azure/hooks/test_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_azure_directory_connection(self, mock_get_conn, mock_credential, mock_b
authority=self.client_secret_auth_config["authority"],
)
mock_blob_service_client.assert_called_once_with(
account_url=conn.host,
account_url=f"https://{conn.login}.blob.core.windows.net/",
credential=mock_credential.return_value,
tenant_id=conn.extra_dejson["tenant_id"],
proxies=conn.extra_dejson["proxies"],
Expand Down

0 comments on commit df74553

Please sign in to comment.