Skip to content

Commit

Permalink
Fix where account url is build if not provided using login (account n…
Browse files Browse the repository at this point in the history
…ame) (#32082)
  • Loading branch information
Adaverse authored Jun 27, 2023
1 parent 0bc689e commit 46ee1c2
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 42 deletions.
69 changes: 36 additions & 33 deletions airflow/providers/microsoft/azure/hooks/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def get_connection_form_widgets() -> dict[str, Any]:
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "extra"],
"hidden_fields": ["schema", "port"],
"relabeling": {
"login": "Blob Storage Login (optional)",
"password": "Blob Storage Key (optional)",
Expand All @@ -140,6 +140,7 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"tenant_id": "tenant",
"shared_access_key": "shared access key",
"sas_token": "account url or token",
"extra": "additional options for use with ClientSecretCredential or DefaultAzureCredential",
},
}

Expand Down Expand Up @@ -176,22 +177,11 @@ def get_conn(self) -> BlobServiceClient:
extra = conn.extra_dejson or {}
client_secret_auth_config = extra.pop("client_secret_auth_config", {})

if self.public_read:
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
return BlobServiceClient(account_url=conn.host, **extra)

connection_string = self._get_field(extra, "connection_string")
if connection_string:
# connection_string auth takes priority
return BlobServiceClient.from_connection_string(connection_string, **extra)

shared_access_key = self._get_field(extra, "shared_access_key")
if shared_access_key:
# using shared access key
return BlobServiceClient(account_url=conn.host, credential=shared_access_key, **extra)

tenant = self._get_field(extra, "tenant_id")
if tenant:
# use Active Directory auth
Expand All @@ -200,22 +190,33 @@ def get_conn(self) -> BlobServiceClient:
token_credential = ClientSecretCredential(tenant, app_id, app_secret, **client_secret_auth_config)
return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra)

account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"

if self.public_read:
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
return BlobServiceClient(account_url=account_url, **extra)

shared_access_key = self._get_field(extra, "shared_access_key")
if shared_access_key:
# using shared access key
return BlobServiceClient(account_url=account_url, credential=shared_access_key, **extra)

sas_token = self._get_field(extra, "sas_token")
if sas_token:
if sas_token.startswith("https"):
return BlobServiceClient(account_url=sas_token, **extra)
else:
return BlobServiceClient(
account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra
)
return BlobServiceClient(account_url=f"{account_url}/{sas_token}", **extra)

# Fall back to old auth (password) or use managed identity if not provided.
credential = conn.password
if not credential:
credential = DefaultAzureCredential()
self.log.info("Using DefaultAzureCredential as credential")
return BlobServiceClient(
account_url=f"https://{conn.login}.blob.core.windows.net/",
account_url=account_url,
credential=credential,
**extra,
)
Expand Down Expand Up @@ -545,13 +546,6 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
extra = conn.extra_dejson or {}
client_secret_auth_config = extra.pop("client_secret_auth_config", {})

if self.public_read:
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
self.blob_service_client = AsyncBlobServiceClient(account_url=conn.host, **extra)
return self.blob_service_client

connection_string = self._get_field(extra, "connection_string")
if connection_string:
# connection_string auth takes priority
Expand All @@ -560,14 +554,6 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
)
return self.blob_service_client

shared_access_key = self._get_field(extra, "shared_access_key")
if shared_access_key:
# using shared access key
self.blob_service_client = AsyncBlobServiceClient(
account_url=conn.host, credential=shared_access_key, **extra
)
return self.blob_service_client

tenant = self._get_field(extra, "tenant_id")
if tenant:
# use Active Directory auth
Expand All @@ -581,13 +567,30 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
)
return self.blob_service_client

account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"

if self.public_read:
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
self.blob_service_client = AsyncBlobServiceClient(account_url=account_url, **extra)
return self.blob_service_client

shared_access_key = self._get_field(extra, "shared_access_key")
if shared_access_key:
# using shared access key
self.blob_service_client = AsyncBlobServiceClient(
account_url=account_url, credential=shared_access_key, **extra
)
return self.blob_service_client

sas_token = self._get_field(extra, "sas_token")
if sas_token:
if sas_token.startswith("https"):
self.blob_service_client = AsyncBlobServiceClient(account_url=sas_token, **extra)
else:
self.blob_service_client = AsyncBlobServiceClient(
account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra
account_url=f"{account_url}/{sas_token}", **extra
)
return self.blob_service_client

Expand All @@ -597,7 +600,7 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
credential = AsyncDefaultAzureCredential()
self.log.info("Using DefaultAzureCredential as credential")
self.blob_service_client = AsyncBlobServiceClient(
account_url=f"https://{conn.login}.blob.core.windows.net/",
account_url=account_url,
credential=credential,
**extra,
)
Expand Down
22 changes: 15 additions & 7 deletions docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,31 @@ Configuring the Connection
--------------------------

Login (optional)
Specify the login used for azure blob storage. For use with Shared Key Credential and SAS Token authentication.
Specify the login used for Azure Blob Storage. Strictly needed for Active Directory (token) authentication as Service principle credential. Optional for the rest if host (account url) is specified.

Password (optional)
Specify the password used for azure blob storage. For use with
Specify the password used for Azure Blob Storage. For use with
Active Directory (token credential) and shared key authentication.

Host (optional)
Specify the account url for anonymous public read, Active Directory, shared access key authentication.
Specify the account url for Azure Blob Storage. Strictly needed for Active Directory (token) authentication as Service principle credential. Optional for the rest if login (account name) is specified.

Blob Storage Connection String (optional)
Connection string for use with connection string authentication.

Blob Storage Shared Access Key (optional)
Specify the shared access key. Needed only for shared access key authentication.

SAS Token (optional)
SAS Token for use with SAS Token authentication.

Tenant Id (Active Directory Auth) (optional)
Specify the tenant to use. Required only for Active Directory (token) authentication.

Extra (optional)
Specify the extra parameters (as json dictionary) that can be used in Azure connection.
The following parameters are all optional:

* ``tenant_id``: Specify the tenant to use. Needed for Active Directory (token) authentication.
* ``shared_access_key``: Specify the shared access key. Needed for shared access key authentication.
* ``connection_string``: Connection string for use with connection string authentication.
* ``sas_token``: SAS Token for use with SAS Token authentication.
* ``client_secret_auth_config``: Extra config to pass while authenticating as a service principal using `ClientSecretCredential <https://learn.microsoft.com/en-in/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python>`_

When specifying the connection in environment variable you should specify
Expand Down
53 changes: 51 additions & 2 deletions tests/providers/microsoft/azure/hooks/test_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,19 @@

class TestWasbHook:
def setup_method(self):
db.merge_conn(Connection(conn_id="wasb_test_key", conn_type="wasb", login="login", password="key"))
self.login = "login"
self.wasb_test_key = "wasb_test_key"
self.connection_type = "wasb"
self.connection_string_id = "azure_test_connection_string"
self.shared_key_conn_id = "azure_shared_key_test"
self.shared_key_conn_id_without_host = "azure_shared_key_test_wihout_host"
self.ad_conn_id = "azure_AD_test"
self.sas_conn_id = "sas_token_id"
self.extra__wasb__sas_conn_id = "extra__sas_token_id"
self.http_sas_conn_id = "http_sas_token_id"
self.extra__wasb__http_sas_conn_id = "extra__http_sas_token_id"
self.public_read_conn_id = "pub_read_id"
self.public_read_conn_id_without_host = "pub_read_id_without_host"
self.managed_identity_conn_id = "managed_identity"
self.authority = "https://test_authority.com"

Expand All @@ -60,6 +63,14 @@ def setup_method(self):
"authority": self.authority,
}

db.merge_conn(
Connection(
conn_id=self.wasb_test_key,
conn_type=self.connection_type,
login=self.login,
password="key",
)
)
db.merge_conn(
Connection(
conn_id=self.public_read_conn_id,
Expand All @@ -68,7 +79,14 @@ def setup_method(self):
extra=json.dumps({"proxies": self.proxies}),
)
)

db.merge_conn(
Connection(
conn_id=self.public_read_conn_id_without_host,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"proxies": self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.connection_string_id,
Expand All @@ -84,6 +102,14 @@ def setup_method(self):
extra=json.dumps({"shared_access_key": "token", "proxies": self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.shared_key_conn_id_without_host,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"shared_access_key": "token", "proxies": self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.ad_conn_id,
Expand Down Expand Up @@ -111,13 +137,15 @@ def setup_method(self):
Connection(
conn_id=self.sas_conn_id,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"sas_token": "token", "proxies": self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.extra__wasb__sas_conn_id,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"extra__wasb__sas_token": "token", "proxies": self.proxies}),
)
)
Expand Down Expand Up @@ -171,6 +199,23 @@ def test_azure_directory_connection(self):
assert isinstance(hook.get_conn(), BlobServiceClient)
assert isinstance(hook.get_conn().credential, ClientSecretCredential)

@pytest.mark.parametrize(
argnames="conn_id_str",
argvalues=[
"wasb_test_key",
"shared_key_conn_id_without_host",
"public_read_conn_id_without_host",
],
)
def test_account_url_without_host(self, conn_id_str):
conn_id = self.__getattribute__(conn_id_str)
hook = WasbHook(wasb_conn_id=conn_id)
hook_conn = hook.get_connection(hook.conn_id)
conn = hook.get_conn()
assert conn.url.startswith("https://")
assert conn.url.__contains__(hook_conn.login)
assert conn.url.endswith(".blob.core.windows.net/")

@pytest.mark.parametrize(
argnames="conn_id_str, extra_key",
argvalues=[
Expand All @@ -187,6 +232,9 @@ def test_sas_token_connection(self, conn_id_str, extra_key):
hook_conn = hook.get_connection(hook.conn_id)
sas_token = hook_conn.extra_dejson[extra_key]
assert isinstance(conn, BlobServiceClient)
assert conn.url.startswith("https://")
if hook_conn.login:
assert conn.url.__contains__(hook_conn.login)
assert conn.url.endswith(sas_token + "/")

@pytest.mark.parametrize(
Expand Down Expand Up @@ -459,4 +507,5 @@ def test___ensure_prefixes(self):
"extra__wasb__tenant_id",
"extra__wasb__shared_access_key",
"extra__wasb__sas_token",
"extra",
]

0 comments on commit 46ee1c2

Please sign in to comment.