Skip to content

Commit

Permalink
Fix updating account url for WasbHook
Browse files Browse the repository at this point in the history
There are different ways users supply the hostname (account URL) in Azure:
sometimes the host doesn't have a urlparse.scheme but has a urlparse.path, e.g. name.blob.windows.net,
and other times it is just the Azure ID, e.g. aldhjf9dads.
While working on apache#32980, I assumed that if there's no scheme, then the hostname is not valid. That is
incorrect, since a DNS name can serve as the host.
The fix is to check whether the parsed host has no netloc and its urlparse.path contains no dot; only in that case do we use the login/account_name to construct
the account_url.
  • Loading branch information
ephraimbuddy committed Aug 17, 2023
1 parent e90febc commit a3c580a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
25 changes: 15 additions & 10 deletions airflow/providers/microsoft/azure/hooks/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import logging
import os
from typing import Any, Union
from urllib.parse import urlparse

from asgiref.sync import sync_to_async
from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError
Expand Down Expand Up @@ -152,11 +153,13 @@ def get_conn(self) -> BlobServiceClient:
# connection_string auth takes priority
return BlobServiceClient.from_connection_string(connection_string, **extra)

account_url = (
conn.host
if conn.host and conn.host.startswith("https://")
else f"https://{conn.login}.blob.core.windows.net/"
)
account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
parsed_url = urlparse(account_url)

if not parsed_url.netloc and "." not in parsed_url.path:
# if there's no netloc and no dots in the path, then user only
# provided the host ID, not the full URL or DNS name
account_url = f"https://{conn.login}.blob.core.windows.net/"

tenant = self._get_field(extra, "tenant_id")
if tenant:
Expand Down Expand Up @@ -555,11 +558,13 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
)
return self.blob_service_client

account_url = (
conn.host
if conn.host and conn.host.startswith("https://")
else f"https://{conn.login}.blob.core.windows.net/"
)
account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
parsed_url = urlparse(account_url)

if not parsed_url.netloc and "." not in parsed_url.path:
# if there's no netloc and no dots in the path, then user only
# provided the host ID, not the full URL or DNS name
account_url = f"https://{conn.login}.blob.core.windows.net/"

tenant = self._get_field(extra, "tenant_id")
if tenant:
Expand Down
27 changes: 27 additions & 0 deletions tests/providers/microsoft/azure/hooks/test_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,33 @@ def test_extra_client_secret_auth_config_ad_connection(self, mock_get_conn):
conn = hook.get_conn()
assert conn.credential._authority == self.authority

@pytest.mark.parametrize(
    "provided_host, expected_host",
    [
        ("https://testaccountname.blob.core.windows.net", "https://testaccountname.blob.core.windows.net"),
        ("testhost", "https://accountlogin.blob.core.windows.net/"),
        ("testhost.dns", "testhost.dns"),
        ("testhost.blob.net", "testhost.blob.net"),
    ],
)
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
def test_proper_account_url_update(self, mock_get_conn, mock_blob_service_client, provided_host, expected_host):
    """Check which account URL reaches ``BlobServiceClient`` for a given connection host.

    Each case pairs the ``host`` stored on the connection with the
    ``account_url`` the patched ``BlobServiceClient`` is expected to receive:

    * a full ``https://`` URL is expected unchanged;
    * a bare ID with no dots is expected to be replaced by a URL built from
      the connection login (``accountlogin``);
    * dotted names without a scheme are expected unchanged.
    """
    # Connection returned by the patched ``WasbHook.get_connection``.
    connection = Connection(
        conn_id="test_conn",
        conn_type=self.connection_type,
        password="testpass",
        login="accountlogin",
        host=provided_host,
    )
    mock_get_conn.return_value = connection

    # The assertion below expects that constructing the hook alone triggers
    # the client creation; no further call is made on the hook here.
    WasbHook(wasb_conn_id=self.shared_key_conn_id)

    mock_blob_service_client.assert_called_once_with(
        account_url=expected_host,
        credential="testpass",
    )

@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
def test_check_for_blob(self, mock_get_conn, mock_service):
Expand Down

0 comments on commit a3c580a

Please sign in to comment.