From 46ee1c2c8d3d0e5793f42fd10bcd80150caa538b Mon Sep 17 00:00:00 2001 From: Akash Sharma <35839624+Adaverse@users.noreply.github.com> Date: Wed, 28 Jun 2023 04:30:11 +0530 Subject: [PATCH] Fix where account url is build if not provided using login (account name) (#32082) --- .../providers/microsoft/azure/hooks/wasb.py | 69 ++++++++++--------- .../connections/wasb.rst | 22 ++++-- .../microsoft/azure/hooks/test_wasb.py | 53 +++++++++++++- 3 files changed, 102 insertions(+), 42 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py index fbad1627be58..c1d615810b96 100644 --- a/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/airflow/providers/microsoft/azure/hooks/wasb.py @@ -126,7 +126,7 @@ def get_connection_form_widgets() -> dict[str, Any]: def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour.""" return { - "hidden_fields": ["schema", "port", "extra"], + "hidden_fields": ["schema", "port"], "relabeling": { "login": "Blob Storage Login (optional)", "password": "Blob Storage Key (optional)", @@ -140,6 +140,7 @@ def get_ui_field_behaviour() -> dict[str, Any]: "tenant_id": "tenant", "shared_access_key": "shared access key", "sas_token": "account url or token", + "extra": "additional options for use with ClientSecretCredential or DefaultAzureCredential", }, } @@ -176,22 +177,11 @@ def get_conn(self) -> BlobServiceClient: extra = conn.extra_dejson or {} client_secret_auth_config = extra.pop("client_secret_auth_config", {}) - if self.public_read: - # Here we use anonymous public read - # more info - # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources - return BlobServiceClient(account_url=conn.host, **extra) - connection_string = self._get_field(extra, "connection_string") if connection_string: # connection_string auth takes priority return BlobServiceClient.from_connection_string(connection_string, **extra) - shared_access_key = self._get_field(extra, "shared_access_key") - if shared_access_key: - # using shared access key - return BlobServiceClient(account_url=conn.host, credential=shared_access_key, **extra) - tenant = self._get_field(extra, "tenant_id") if tenant: # use Active Directory auth @@ -200,14 +190,25 @@ def get_conn(self) -> BlobServiceClient: token_credential = ClientSecretCredential(tenant, app_id, app_secret, **client_secret_auth_config) return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra) + account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/" + + if self.public_read: + # Here we use anonymous public read + # more info + # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources + return BlobServiceClient(account_url=account_url, **extra) + + shared_access_key = self._get_field(extra, "shared_access_key") + if shared_access_key: + # using shared access key + return BlobServiceClient(account_url=account_url, credential=shared_access_key, **extra) + sas_token = self._get_field(extra, "sas_token") if sas_token: if sas_token.startswith("https"): return BlobServiceClient(account_url=sas_token, **extra) else: - return BlobServiceClient( - account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra - ) + return BlobServiceClient(account_url=f"{account_url}/{sas_token}", **extra) # Fall back to old auth (password) or use managed identity if not provided. credential = conn.password @@ -215,7 +216,7 @@ def get_conn(self) -> BlobServiceClient: credential = DefaultAzureCredential() self.log.info("Using DefaultAzureCredential as credential") return BlobServiceClient( - account_url=f"https://{conn.login}.blob.core.windows.net/", + account_url=account_url, credential=credential, **extra, ) @@ -545,13 +546,6 @@ async def get_async_conn(self) -> AsyncBlobServiceClient: extra = conn.extra_dejson or {} client_secret_auth_config = extra.pop("client_secret_auth_config", {}) - if self.public_read: - # Here we use anonymous public read - # more info - # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources - self.blob_service_client = AsyncBlobServiceClient(account_url=conn.host, **extra) - return self.blob_service_client - connection_string = self._get_field(extra, "connection_string") if connection_string: # connection_string auth takes priority @@ -560,14 +554,6 @@ async def get_async_conn(self) -> AsyncBlobServiceClient: ) return self.blob_service_client - shared_access_key = self._get_field(extra, "shared_access_key") - if shared_access_key: - # using shared access key - self.blob_service_client = AsyncBlobServiceClient( - account_url=conn.host, credential=shared_access_key, **extra - ) - return self.blob_service_client - tenant = self._get_field(extra, "tenant_id") if tenant: # use Active Directory auth @@ -581,13 +567,30 @@ async def get_async_conn(self) -> AsyncBlobServiceClient: ) return self.blob_service_client + account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/" + + if self.public_read: + # Here we use anonymous public read + # more info + # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources + self.blob_service_client = AsyncBlobServiceClient(account_url=account_url, **extra) + return self.blob_service_client + + shared_access_key = self._get_field(extra, "shared_access_key") + if shared_access_key: + # using shared access key + self.blob_service_client = AsyncBlobServiceClient( + account_url=account_url, credential=shared_access_key, **extra + ) + return self.blob_service_client + sas_token = self._get_field(extra, "sas_token") if sas_token: if sas_token.startswith("https"): self.blob_service_client = AsyncBlobServiceClient(account_url=sas_token, **extra) else: self.blob_service_client = AsyncBlobServiceClient( - account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra + account_url=f"{account_url}/{sas_token}", **extra ) return self.blob_service_client @@ -597,7 +600,7 @@ async def get_async_conn(self) -> AsyncBlobServiceClient: credential = AsyncDefaultAzureCredential() self.log.info("Using DefaultAzureCredential as credential") self.blob_service_client = AsyncBlobServiceClient( - account_url=f"https://{conn.login}.blob.core.windows.net/", + account_url=account_url, credential=credential, **extra, ) diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst b/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst index ce057e1592ab..8efdeef3628f 100644 --- a/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst +++ b/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst @@ -54,23 +54,31 @@ Configuring the Connection -------------------------- Login (optional) - Specify the login used for azure blob storage. For use with Shared Key Credential and SAS Token authentication. + Specify the login used for Azure Blob Storage. Strictly needed for Active Directory (token) authentication as Service principle credential. Optional for the rest if host (account url) is specified. Password (optional) - Specify the password used for azure blob storage. For use with + Specify the password used for Azure Blob Storage. For use with Active Directory (token credential) and shared key authentication. Host (optional) - Specify the account url for anonymous public read, Active Directory, shared access key authentication. + Specify the account url for Azure Blob Storage. Strictly needed for Active Directory (token) authentication as Service principle credential. Optional for the rest if login (account name) is specified. + +Blob Storage Connection String (optional) + Connection string for use with connection string authentication. + +Blob Storage Shared Access Key (optional) + Specify the shared access key. Needed only for shared access key authentication. + +SAS Token (optional) + SAS Token for use with SAS Token authentication. + +Tenant Id (Active Directory Auth) (optional) + Specify the tenant to use. Required only for Active Directory (token) authentication. Extra (optional) Specify the extra parameters (as json dictionary) that can be used in Azure connection. The following parameters are all optional: - * ``tenant_id``: Specify the tenant to use. Needed for Active Directory (token) authentication. - * ``shared_access_key``: Specify the shared access key. Needed for shared access key authentication. - * ``connection_string``: Connection string for use with connection string authentication. - * ``sas_token``: SAS Token for use with SAS Token authentication. * ``client_secret_auth_config``: Extra config to pass while authenticating as a service principal using `ClientSecretCredential `_ When specifying the connection in environment variable you should specify diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py index 6d837a2fa235..464db0f39f43 100644 --- a/tests/providers/microsoft/azure/hooks/test_wasb.py +++ b/tests/providers/microsoft/azure/hooks/test_wasb.py @@ -40,16 +40,19 @@ class TestWasbHook: def setup_method(self): - db.merge_conn(Connection(conn_id="wasb_test_key", conn_type="wasb", login="login", password="key")) + self.login = "login" + self.wasb_test_key = "wasb_test_key" self.connection_type = "wasb" self.connection_string_id = "azure_test_connection_string" self.shared_key_conn_id = "azure_shared_key_test" + self.shared_key_conn_id_without_host = "azure_shared_key_test_wihout_host" self.ad_conn_id = "azure_AD_test" self.sas_conn_id = "sas_token_id" self.extra__wasb__sas_conn_id = "extra__sas_token_id" self.http_sas_conn_id = "http_sas_token_id" self.extra__wasb__http_sas_conn_id = "extra__http_sas_token_id" self.public_read_conn_id = "pub_read_id" + self.public_read_conn_id_without_host = "pub_read_id_without_host" self.managed_identity_conn_id = "managed_identity" self.authority = "https://test_authority.com" @@ -60,6 +63,14 @@ def setup_method(self): "authority": self.authority, } + db.merge_conn( + Connection( + conn_id=self.wasb_test_key, + conn_type=self.connection_type, + login=self.login, + password="key", + ) + ) db.merge_conn( Connection( conn_id=self.public_read_conn_id, @@ -68,7 +79,14 @@ def setup_method(self): extra=json.dumps({"proxies": self.proxies}), ) ) - + db.merge_conn( + Connection( + conn_id=self.public_read_conn_id_without_host, + conn_type=self.connection_type, + login=self.login, + extra=json.dumps({"proxies": self.proxies}), + ) + ) db.merge_conn( Connection( conn_id=self.connection_string_id, @@ -84,6 +102,14 @@ def setup_method(self): extra=json.dumps({"shared_access_key": "token", "proxies": self.proxies}), ) ) + db.merge_conn( + Connection( + conn_id=self.shared_key_conn_id_without_host, + conn_type=self.connection_type, + login=self.login, + extra=json.dumps({"shared_access_key": "token", "proxies": self.proxies}), + ) + ) db.merge_conn( Connection( conn_id=self.ad_conn_id, @@ -111,6 +137,7 @@ def setup_method(self): Connection( conn_id=self.sas_conn_id, conn_type=self.connection_type, + login=self.login, extra=json.dumps({"sas_token": "token", "proxies": self.proxies}), ) ) @@ -118,6 +145,7 @@ def setup_method(self): Connection( conn_id=self.extra__wasb__sas_conn_id, conn_type=self.connection_type, + login=self.login, extra=json.dumps({"extra__wasb__sas_token": "token", "proxies": self.proxies}), ) ) @@ -171,6 +199,23 @@ def test_azure_directory_connection(self): assert isinstance(hook.get_conn(), BlobServiceClient) assert isinstance(hook.get_conn().credential, ClientSecretCredential) + @pytest.mark.parametrize( + argnames="conn_id_str", + argvalues=[ + "wasb_test_key", + "shared_key_conn_id_without_host", + "public_read_conn_id_without_host", + ], + ) + def test_account_url_without_host(self, conn_id_str): + conn_id = self.__getattribute__(conn_id_str) + hook = WasbHook(wasb_conn_id=conn_id) + hook_conn = hook.get_connection(hook.conn_id) + conn = hook.get_conn() + assert conn.url.startswith("https://") + assert conn.url.__contains__(hook_conn.login) + assert conn.url.endswith(".blob.core.windows.net/") + @pytest.mark.parametrize( argnames="conn_id_str, extra_key", argvalues=[ @@ -187,6 +232,9 @@ def test_sas_token_connection(self, conn_id_str, extra_key): hook_conn = hook.get_connection(hook.conn_id) sas_token = hook_conn.extra_dejson[extra_key] assert isinstance(conn, BlobServiceClient) + assert conn.url.startswith("https://") + if hook_conn.login: + assert conn.url.__contains__(hook_conn.login) assert conn.url.endswith(sas_token + "/") @pytest.mark.parametrize( @@ -459,4 +507,5 @@ def test___ensure_prefixes(self): "extra__wasb__tenant_id", "extra__wasb__shared_access_key", "extra__wasb__sas_token", + "extra", ]