diff --git a/airflow/providers/microsoft/azure/hooks/asb.py b/airflow/providers/microsoft/azure/hooks/asb.py new file mode 100644 index 0000000000000..e99296b602e6f --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/asb.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, List, Optional, Union + +from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusSender +from azure.servicebus.management import QueueProperties, ServiceBusAdministrationClient + +from airflow.hooks.base import BaseHook + + +class BaseAzureServiceBusHook(BaseHook): + """ + BaseAzureServiceBusHook class to create session and create connection using connection string + + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + conn_name_attr = 'azure_service_bus_conn_id' + default_conn_name = 'azure_service_bus_default' + conn_type = 'azure_service_bus' + hook_name = 'Azure Service Bus' + + @staticmethod + def get_ui_field_behaviour() -> Dict[str, Any]: + """Returns custom field behaviour""" + return { + "hidden_fields": ['port', 'host', 'extra', 'login', 'password'], + "relabeling": {'schema': 'Connection String'}, + "placeholders": { + 'schema': 'Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=;SharedAccessKey=', # noqa + }, + } + + def __init__(self, azure_service_bus_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_service_bus_conn_id + + def get_conn(self): + raise NotImplementedError + + +class AdminClientHook(BaseAzureServiceBusHook): + """ + Interacts with ServiceBusAdministrationClient client + to create, update, list, and delete resources of a + Service Bus namespace. This hook uses the same Azure Service Bus client connection inherited + from the base class + """ + + def get_conn(self) -> ServiceBusAdministrationClient: + """ + Create and returns ServiceBusAdministrationClient by using the connection + string in connection details + """ + conn = self.get_connection(self.conn_id) + + connection_string: str = str(conn.schema) + return ServiceBusAdministrationClient.from_connection_string(connection_string) + + def create_queue( + self, + queue_name: str, + max_delivery_count: int = 10, + dead_lettering_on_message_expiration: bool = True, + enable_batched_operations: bool = True, + ) -> QueueProperties: + """ + Create Queue by connecting to service Bus Admin client return the QueueProperties + + :param queue_name: The name of the queue or a QueueProperties with name. + :param max_delivery_count: The maximum delivery count. A message is automatically + dead lettered after this number of deliveries. Default value is 10.. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription has + dead letter support when a message expires. + :param enable_batched_operations: Value that indicates whether server-side batched + operations are enabled. + """ + if queue_name is None: + raise TypeError("Queue name cannot be None.") + + with self.get_conn() as service_mgmt_conn: + queue = service_mgmt_conn.create_queue( + queue_name, + max_delivery_count=max_delivery_count, + dead_lettering_on_message_expiration=dead_lettering_on_message_expiration, + enable_batched_operations=enable_batched_operations, + ) + return queue + + def delete_queue(self, queue_name: str) -> None: + """ + Delete the queue by queue_name in service bus namespace + + :param queue_name: The name of the queue or a QueueProperties with name. + """ + if queue_name is None: + raise TypeError("Queue name cannot be None.") + + with self.get_conn() as service_mgmt_conn: + service_mgmt_conn.delete_queue(queue_name) + + +class MessageHook(BaseAzureServiceBusHook): + """ + Interacts with ServiceBusClient and acts as a high level interface + for getting ServiceBusSender and ServiceBusReceiver. + """ + + def get_conn(self) -> ServiceBusClient: + """Create and returns ServiceBusClient by using the connection string in connection details""" + conn = self.get_connection(self.conn_id) + connection_string: str = str(conn.schema) + + self.log.info("Create and returns ServiceBusClient") + return ServiceBusClient.from_connection_string(conn_str=connection_string, logging_enable=True) + + def send_message( + self, queue_name: str, messages: Union[str, List[str]], batch_message_flag: bool = False + ): + """ + By using ServiceBusClient Send message(s) to a Service Bus Queue. By using + batch_message_flag it enables and send message as batch message + + :param queue_name: The name of the queue or a QueueProperties with name. + :param messages: Message which needs to be sent to the queue. It can be string or list of string. + :param batch_message_flag: bool flag, can be set to True if message needs to be sent as batch message. + """ + if queue_name is None: + raise TypeError("Queue name cannot be None.") + if not messages: + raise ValueError("Messages list cannot be empty.") + with self.get_conn() as service_bus_client, service_bus_client.get_queue_sender( + queue_name=queue_name + ) as sender: + with sender: + if isinstance(messages, str): + if not batch_message_flag: + msg = ServiceBusMessage(messages) + sender.send_messages(msg) + else: + self.send_batch_message(sender, [messages]) + else: + if not batch_message_flag: + self.send_list_messages(sender, messages) + else: + self.send_batch_message(sender, messages) + + @staticmethod + def send_list_messages(sender: ServiceBusSender, messages: List[str]): + list_messages = [ServiceBusMessage(message) for message in messages] + sender.send_messages(list_messages) # type: ignore[arg-type] + + @staticmethod + def send_batch_message(sender: ServiceBusSender, messages: List[str]): + batch_message = sender.create_message_batch() + for message in messages: + batch_message.add_message(ServiceBusMessage(message)) + sender.send_messages(batch_message) + + def receive_message( + self, queue_name, max_message_count: Optional[int] = 1, max_wait_time: Optional[float] = None + ): + """ + Receive a batch of messages at once in a specified Queue name + + :param queue_name: The name of the queue name or a QueueProperties with name. + :param max_message_count: Maximum number of messages in the batch. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. + """ + if queue_name is None: + raise TypeError("Queue name cannot be None.") + + with self.get_conn() as service_bus_client, service_bus_client.get_queue_receiver( + queue_name=queue_name + ) as receiver: + with receiver: + received_msgs = receiver.receive_messages( + max_message_count=max_message_count, max_wait_time=max_wait_time + ) + for msg in received_msgs: + self.log.info(msg) + receiver.complete_message(msg) diff --git a/airflow/providers/microsoft/azure/operators/asb.py b/airflow/providers/microsoft/azure/operators/asb.py new file mode 100644 index 0000000000000..f8c363c678510 --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/asb.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import TYPE_CHECKING, List, Sequence, Union + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AzureServiceBusCreateQueueOperator(BaseOperator): + """ + Creates a Azure Service Bus queue under a Service Bus Namespace by using ServiceBusAdministrationClient + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusCreateQueueOperator` + + :param queue_name: The name of the queue. should be unique. + :param max_delivery_count: The maximum delivery count. A message is automatically + dead lettered after this number of deliveries. Default value is 10.. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription has + dead letter support when a message expires. + :param enable_batched_operations: Value that indicates whether server-side batched + operations are enabled. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + max_delivery_count: int = 10, + dead_lettering_on_message_expiration: bool = True, + enable_batched_operations: bool = True, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.max_delivery_count = max_delivery_count + self.dead_lettering_on_message_expiration = dead_lettering_on_message_expiration + self.enable_batched_operations = enable_batched_operations + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Creates Queue in Azure Service Bus namespace, by connecting to Service Bus Admin client in hook""" + hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # create queue with name + queue = hook.create_queue( + self.queue_name, + self.max_delivery_count, + self.dead_lettering_on_message_expiration, + self.enable_batched_operations, + ) + self.log.info("Created Queue %s", queue.name) + + +class AzureServiceBusSendMessageOperator(BaseOperator): + """ + Send Message or batch message to the Service Bus queue + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusSendMessageOperator` + + :param queue_name: The name of the queue. should be unique. + :param message: Message which needs to be sent to the queue. It can be string or list of string. + :param batch: Its boolean flag by default it is set to False, if the message needs to be sent + as batch message it can be set to True. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + message: Union[str, List[str]], + batch: bool = False, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.batch = batch + self.message = message + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """ + Sends Message to the specific queue in Service Bus namespace, by + connecting to Service Bus client + """ + # Create the hook + hook = MessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # send message + hook.send_message(self.queue_name, self.message, self.batch) + + +class AzureServiceBusReceiveMessageOperator(BaseOperator): + """ + Receive a batch of messages at once in a specified Queue name + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusReceiveMessageOperator` + + :param queue_name: The name of the queue name or a QueueProperties with name. + :param max_message_count: Maximum number of messages in the batch. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + max_message_count: int = 10, + max_wait_time: float = 5, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + self.max_message_count = max_message_count + self.max_wait_time = max_wait_time + + def execute(self, context: "Context") -> None: + """ + Receive Message in specific queue in Service Bus namespace, + by connecting to Service Bus client + """ + # Create the hook + hook = MessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # Receive message + hook.receive_message( + self.queue_name, max_message_count=self.max_message_count, max_wait_time=self.max_wait_time + ) + + +class AzureServiceBusDeleteQueueOperator(BaseOperator): + """ + Deletes the Queue in the Azure Service Bus namespace + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusDeleteQueueOperator` + + :param queue_name: The name of the queue in Service Bus namespace. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + azure_service_bus_conn_id: str = 'azure_service_bus_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: "Context") -> None: + """Delete Queue in Service Bus namespace, by connecting to Service Bus Admin client""" + # Create the hook + hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # delete queue with name + hook.delete_queue(self.queue_name) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index e4ca164b2347d..e4373d9f090d5 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -89,6 +89,12 @@ integrations: external-doc-url: https://azure.microsoft.com/ logo: /integration-logos/azure/Microsoft-Azure.png tags: [azure] + - integration-name: Microsoft Azure Service Bus + external-doc-url: https://azure.microsoft.com/en-us/services/service-bus/ + logo: /integration-logos/azure/Service-Bus.svg + how-to-guide: + - /docs/apache-airflow-providers-microsoft-azure/operators/asb.rst + tags: [azure] operators: - integration-name: Microsoft Azure Data Lake Storage @@ -117,6 +123,9 @@ operators: - integration-name: Microsoft Azure Data Factory python-modules: - airflow.providers.microsoft.azure.operators.data_factory + - integration-name: Microsoft Azure Service Bus + python-modules: + - airflow.providers.microsoft.azure.operators.asb sensors: - integration-name: Microsoft Azure Cosmos DB @@ -168,6 +177,9 @@ hooks: python-modules: - airflow.providers.microsoft.azure.hooks.data_factory - airflow.providers.microsoft.azure.hooks.azure_data_factory + - integration-name: Microsoft Azure Service Bus + python-modules: + - airflow.providers.microsoft.azure.hooks.asb transfers: - source-integration-name: Local @@ -204,6 +216,7 @@ hook-class-names: # deprecated - to be removed after providers add dependency o - airflow.providers.microsoft.azure.hooks.wasb.WasbHook - airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook - airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook + - airflow.providers.microsoft.azure.hooks.asb.BaseAzureServiceBusHook connection-types: - hook-class-name: airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook @@ -230,6 +243,8 @@ connection-types: - hook-class-name: >- airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook connection-type: azure_container_registry + - hook-class-name: airflow.providers.microsoft.azure.hooks.asb.BaseAzureServiceBusHook + connection-type: azure_service_bus secrets-backends: - airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/asb.rst b/docs/apache-airflow-providers-microsoft-azure/connections/asb.rst new file mode 100644 index 0000000000000..daf50d6017a50 --- /dev/null +++ b/docs/apache-airflow-providers-microsoft-azure/connections/asb.rst @@ -0,0 +1,50 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +.. _howto/connection:azure_service_bus: + +Microsoft Azure Service Bus +======================================= + +The Microsoft Azure Service Bus connection type enables the Azure Service Bus Integration. + +Authenticating to Azure Service Bus +------------------------------------ + +There are multiple ways to authenticate and authorize access to Azure Service Bus resources: +Currently Supports Shared Access Signatures (SAS). + +1. Use a `Connection String + `_ + i.e. Use connection string Field to add ``Connection String`` in the Airflow connection. + +Default Connection IDs +---------------------- + +All hooks and operators related to Microsoft Azure Service Bus use ``azure_service_bus_default`` by default. + +Configuring the Connection +-------------------------- + +Connection String + Specify the Azure Service bus connection string ID used for the initial connection. + Please find the documentation on how to generate connection string in azure service bus + `Get connection string + `_ + Use the key ``connection_string`` to pass in the Connection ID . diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/asb.rst b/docs/apache-airflow-providers-microsoft-azure/operators/asb.rst new file mode 100644 index 0000000000000..96b27a7d8017e --- /dev/null +++ b/docs/apache-airflow-providers-microsoft-azure/operators/asb.rst @@ -0,0 +1,107 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Azure Service Bus Operators +============================ +Azure Service Bus is a fully managed enterprise message broker with message queues and +publish-subscribe topics (in a namespace). Service Bus is used to decouple applications +and services from each other. Service Bus that perform operations on +entities, such as namespaces, queues, and topics. + +The Service Bus REST API provides operations for working with the following resources: + - Azure Resource Manager + - Service Bus service + +Azure Service Bus Queue Operators +--------------------------------- +Azure Service Bus Operators helps to interact with Azure Bus Queue based operation like Create, Delete, +Send and Receive message in Queue. + +.. _howto/operator:AzureServiceBusCreateQueueOperator: + +Create Azure Service Bus Queue +=============================== + +To create Azure service bus queue with specific Parameter you can use +:class:`~airflow.providers.microsoft.azure.operators.asb.AzureServiceBusCreateQueueOperator`. + +Below is an example of using this operator to execute an Azure Service Bus Create Queue. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_azure_service_bus.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_create_service_bus_queue] + :end-before: [END howto_operator_create_service_bus_queue] + + +.. _howto/operator:AzureServiceBusSendMessageOperator: + +Send Message to Azure Service Bus Queue +======================================= + +To Send message or list of message or batch Message to the Azure Service Bus Queue. You can use +:class:`~airflow.providers.microsoft.azure.operators.asb.AzureServiceBusSendMessageOperator`. + +Below is an example of using this operator to execute an Azure Service Bus Send Message to Queue. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_azure_service_bus.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_send_message_to_service_bus_queue] + :end-before: [END howto_operator_send_message_to_service_bus_queue] + + +.. _howto/operator:AzureServiceBusReceiveMessageOperator: + +Receive Message Azure Service Bus Queue +======================================== + +To Receive Message or list of message or Batch message message in a Queue you can use +:class:`~airflow.providers.microsoft.azure.operators.asb.AzureServiceBusReceiveMessageOperator`. + +Below is an example of using this operator to execute an Azure Service Bus Create Queue. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_azure_service_bus.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_receive_message_service_bus_queue] + :end-before: [END howto_operator_receive_message_service_bus_queue] + + +.. _howto/operator:AzureServiceBusDeleteQueueOperator: + +Delete Azure Service Bus Queue +=============================== + +To Delete the Azure service bus queue you can use +:class:`~airflow.providers.microsoft.azure.operators.asb.AzureServiceBusDeleteQueueOperator`. + +Below is an example of using this operator to execute an Azure Service Bus Delete Queue. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_azure_service_bus.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_delete_service_bus_queue] + :end-before: [END howto_operator_delete_service_bus_queue] + + +Reference +--------- + +For further information, please refer to the Microsoft documentation: + + * `Azure Service Bus Documentation `__ diff --git a/docs/integration-logos/azure/Service-Bus.svg b/docs/integration-logos/azure/Service-Bus.svg new file mode 100644 index 0000000000000..1604e04232630 --- /dev/null +++ b/docs/integration-logos/azure/Service-Bus.svg @@ -0,0 +1 @@ + diff --git a/setup.py b/setup.py index c52a682f668dc..c997322b12bf2 100644 --- a/setup.py +++ b/setup.py @@ -232,6 +232,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'azure-storage-blob>=12.7.0,<12.9.0', 'azure-storage-common>=2.1.0', 'azure-storage-file>=2.1.0', + 'azure-servicebus>=7.6.1', ] cassandra = [ 'cassandra-driver>=3.13.0', diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py b/tests/providers/microsoft/azure/hooks/test_asb.py new file mode 100644 index 0000000000000..315a31802a4d8 --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_asb.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest import mock + +import pytest +from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusMessageBatch +from azure.servicebus.management import ServiceBusAdministrationClient + +from airflow.models import Connection +from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook + +MESSAGE = "Test Message" +MESSAGE_LIST = [MESSAGE + " " + str(n) for n in range(0, 10)] + + +class TestAdminClientHook: + def setup_class(self) -> None: + self.queue_name: str = "test_queue" + self.conn_id: str = 'azure_service_bus_default' + self.connection_string = ( + "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;" + "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc" + ) + self.mock_conn = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + schema=self.connection_string, + ) + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_connection") + def test_get_conn(self, mock_connection): + mock_connection.return_value = self.mock_conn + hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id) + assert isinstance(hook.get_conn(), ServiceBusAdministrationClient) + + @mock.patch('azure.servicebus.management.QueueProperties') + @mock.patch('airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn') + def test_create_queue(self, mock_sb_admin_client, mock_queue_properties): + """ + Test `create_queue` hook function with mocking connection, queue properties value and + the azure service bus `create_queue` function + """ + mock_queue_properties.name = self.queue_name + mock_sb_admin_client.return_value.__enter__.return_value.create_queue.return_value = ( + mock_queue_properties + ) + hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id) + response = hook.create_queue(self.queue_name) + assert response == mock_queue_properties + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb.ServiceBusAdministrationClient') + def test_create_queue_exception(self, mock_sb_admin_client): + """Test `create_queue` functionality to raise ValueError by passing queue name as None""" + hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(TypeError): + hook.create_queue(None) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn') + def test_delete_queue(self, mock_sb_admin_client): + """ + Test Delete queue functionality by passing queue name, assert the function with values, + mock the azure service bus function `delete_queue` + """ + hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id) + hook.delete_queue(self.queue_name) + expected_calls = [mock.call().__enter__().delete_queue(self.queue_name)] + mock_sb_admin_client.assert_has_calls(expected_calls) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb.ServiceBusAdministrationClient') + def test_delete_queue_exception(self, mock_sb_admin_client): + """Test `delete_queue` functionality to raise ValueError, by passing queue name as None""" + hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(TypeError): + hook.delete_queue(None) + + +class TestMessageHook: + def setup_class(self) -> None: + self.queue_name: str = "test_queue" + self.conn_id: str = 'azure_service_bus_default' + self.connection_string = ( + "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;" + "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc" + ) + self.conn = Connection( + conn_id='azure_service_bus_default', + conn_type='azure_service_bus', + schema=self.connection_string, + ) + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_connection") + def test_get_service_bus_message_conn(self, mock_connection): + """ + Test get_conn() function and check whether the get_conn() function returns value + is instance of ServiceBusClient + """ + mock_connection.return_value = self.conn + hook = MessageHook(azure_service_bus_conn_id=self.conn_id) + assert isinstance(hook.get_conn(), ServiceBusClient) + + @pytest.mark.parametrize( + "mock_message, mock_batch_flag", + [ + (MESSAGE, True), + (MESSAGE, False), + (MESSAGE_LIST, True), + (MESSAGE_LIST, False), + ], + ) + @mock.patch('airflow.providers.microsoft.azure.hooks.asb.MessageHook.send_list_messages') + @mock.patch('airflow.providers.microsoft.azure.hooks.asb.MessageHook.send_batch_message') + @mock.patch('airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn') + def test_send_message( + self, mock_sb_client, mock_batch_message, mock_list_message, mock_message, mock_batch_flag + ): + """ + Test `send_message` hook function with batch flag and message passed as mocked params, + which can be string or list of string, mock the azure service bus `send_messages` function + """ + hook = MessageHook(azure_service_bus_conn_id="azure_service_bus_default") + hook.send_message( + queue_name=self.queue_name, messages=mock_message, batch_message_flag=mock_batch_flag + ) + if isinstance(mock_message, list): + if mock_batch_flag: + message = ServiceBusMessageBatch(mock_message) + else: + message = [ServiceBusMessage(msg) for msg in mock_message] + elif isinstance(mock_message, str): + if mock_batch_flag: + message = ServiceBusMessageBatch(mock_message) + else: + message = ServiceBusMessage(mock_message) + + expected_calls = [ + mock.call() + .__enter__() + .get_queue_sender(self.queue_name) + .__enter__() + .send_messages(message) + .__exit__() + ] + mock_sb_client.assert_has_calls(expected_calls, any_order=False) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn') + def test_send_message_exception(self, mock_sb_client): + """ + Test `send_message` functionality to raise AirflowException in Azure MessageHook + by passing queue name as None + """ + hook = MessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(TypeError): + hook.send_message(queue_name=None, messages="", batch_message_flag=False) + + @mock.patch('azure.servicebus.ServiceBusMessage') + @mock.patch('airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn') + def test_receive_message(self, mock_sb_client, mock_service_bus_message): + """ + Test `receive_message` hook function and assert the function with mock value, + mock the azure service bus `receive_messages` function + """ + hook = MessageHook(azure_service_bus_conn_id=self.conn_id) + mock_sb_client.return_value.get_queue_receiver.return_value.receive_messages.return_value = [ + mock_service_bus_message + ] + hook.receive_message(self.queue_name) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_receiver(self.queue_name) + .__enter__() + .receive_messages(max_message_count=30, max_wait_time=5) + .get_queue_receiver(self.queue_name) + .__exit__() + .mock_call() + .__exit__ + ] + mock_sb_client.assert_has_calls(expected_calls) + + @mock.patch('airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn') + def test_receive_message_exception(self, mock_sb_client): + """ + Test `receive_message` functionality to raise AirflowException in Azure MessageHook + by passing queue name as None + """ + hook = MessageHook(azure_service_bus_conn_id=self.conn_id) + with pytest.raises(TypeError): + hook.receive_message(None) diff --git a/tests/providers/microsoft/azure/operators/test_asb.py b/tests/providers/microsoft/azure/operators/test_asb.py new file mode 100644 index 0000000000000..44fb6304340a0 --- /dev/null +++ b/tests/providers/microsoft/azure/operators/test_asb.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest import mock + +import pytest +from azure.servicebus import ServiceBusMessage + +from airflow.providers.microsoft.azure.operators.asb import ( + AzureServiceBusCreateQueueOperator, + AzureServiceBusDeleteQueueOperator, + AzureServiceBusReceiveMessageOperator, + AzureServiceBusSendMessageOperator, +) + +QUEUE_NAME = "test_queue" +MESSAGE = "Test Message" +MESSAGE_LIST = [MESSAGE + " " + str(n) for n in range(0, 10)] + + +class TestAzureServiceBusCreateQueueOperator: + @pytest.mark.parametrize( + "mock_dl_msg_expiration, mock_batched_operation", + [ + (True, True), + (True, False), + (False, True), + (False, False), + ], + ) + def test_init(self, mock_dl_msg_expiration, mock_batched_operation): + """ + Test init by creating AzureServiceBusCreateQueueOperator with task id, + queue_name and asserting with value + """ + asb_create_queue_operator = AzureServiceBusCreateQueueOperator( + task_id="asb_create_queue", + queue_name=QUEUE_NAME, + max_delivery_count=10, + dead_lettering_on_message_expiration=mock_dl_msg_expiration, + enable_batched_operations=mock_batched_operation, + ) + assert asb_create_queue_operator.task_id == "asb_create_queue" + assert asb_create_queue_operator.queue_name == QUEUE_NAME + assert asb_create_queue_operator.max_delivery_count == 10 + assert asb_create_queue_operator.dead_lettering_on_message_expiration is mock_dl_msg_expiration + assert asb_create_queue_operator.enable_batched_operations is mock_batched_operation + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn") + def test_create_queue(self, mock_get_conn): + """ + Test AzureServiceBusCreateQueueOperator passed with the queue name, + mocking the connection details, hook create_queue function + """ + asb_create_queue_operator = AzureServiceBusCreateQueueOperator( + task_id="asb_create_queue_operator", + queue_name=QUEUE_NAME, + max_delivery_count=10, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + ) + asb_create_queue_operator.execute(None) + mock_get_conn.return_value.__enter__.return_value.create_queue.assert_called_once_with( + QUEUE_NAME, + max_delivery_count=10, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + ) + + +class TestAzureServiceBusDeleteQueueOperator: + def test_init(self): + """ + Test init by creating AzureServiceBusDeleteQueueOperator with task id, queue_name and asserting + with values + """ + asb_delete_queue_operator = AzureServiceBusDeleteQueueOperator( + task_id="asb_delete_queue", + queue_name=QUEUE_NAME, + ) + assert asb_delete_queue_operator.task_id == "asb_delete_queue" + assert asb_delete_queue_operator.queue_name == QUEUE_NAME + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn") + def test_delete_queue(self, mock_get_conn): + """Test AzureServiceBusDeleteQueueOperator by mocking queue name, connection and hook delete_queue""" + asb_delete_queue_operator = AzureServiceBusDeleteQueueOperator( + task_id="asb_delete_queue", + queue_name=QUEUE_NAME, + ) + asb_delete_queue_operator.execute(None) + mock_get_conn.return_value.__enter__.return_value.delete_queue.assert_called_once_with(QUEUE_NAME) + + +class TestAzureServiceBusSendMessageOperator: + @pytest.mark.parametrize( + "mock_message, mock_batch_flag", + [ + (MESSAGE, True), + (MESSAGE, False), + (MESSAGE_LIST, True), + (MESSAGE_LIST, False), + ], + ) + def test_init(self, mock_message, mock_batch_flag): + """ + Test init by creating AzureServiceBusSendMessageOperator with task id, queue_name, message, + batch and asserting with values + """ + asb_send_message_queue_operator = AzureServiceBusSendMessageOperator( + task_id="asb_send_message_queue_without_batch", + queue_name=QUEUE_NAME, + message=mock_message, + batch=mock_batch_flag, + ) + assert asb_send_message_queue_operator.task_id == "asb_send_message_queue_without_batch" + assert asb_send_message_queue_operator.queue_name == QUEUE_NAME + assert asb_send_message_queue_operator.message == mock_message + assert asb_send_message_queue_operator.batch is mock_batch_flag + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn") + def test_send_message_queue(self, mock_get_conn): + """ + Test AzureServiceBusSendMessageOperator with queue name, batch boolean flag, mock + the send_messages of azure service bus function + """ + asb_send_message_queue_operator = AzureServiceBusSendMessageOperator( + task_id="asb_send_message_queue", + queue_name=QUEUE_NAME, + message="Test message", + batch=False, + ) + asb_send_message_queue_operator.execute(None) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_sender(QUEUE_NAME) + .__enter__() + .send_messages(ServiceBusMessage("Test message")) + .__exit__() + ] + mock_get_conn.assert_has_calls(expected_calls, any_order=False) + + +class TestAzureServiceBusReceiveMessageOperator: + def test_init(self): + """ + Test init by creating AzureServiceBusReceiveMessageOperator with task id, queue_name, message, + batch and asserting with values + """ + + asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator( + task_id="asb_receive_message_queue", + queue_name=QUEUE_NAME, + ) + assert asb_receive_queue_operator.task_id == "asb_receive_message_queue" + assert asb_receive_queue_operator.queue_name == QUEUE_NAME + + @mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn") + def test_receive_message_queue(self, mock_get_conn): + """ + Test AzureServiceBusReceiveMessageOperator by mock connection, values + and the service bus receive message + """ + asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator( + task_id="asb_receive_message_queue", + queue_name=QUEUE_NAME, + ) + asb_receive_queue_operator.execute(None) + expected_calls = [ + mock.call() + .__enter__() + .get_queue_receiver(QUEUE_NAME) + .__enter__() + .receive_messages(max_message_count=10, max_wait_time=5) + .get_queue_receiver(QUEUE_NAME) + .__exit__() + .mock_call() + .__exit__ + ] + mock_get_conn.assert_has_calls(expected_calls) diff --git a/tests/system/providers/microsoft/azure/example_azure_service_bus.py b/tests/system/providers/microsoft/azure/example_azure_service_bus.py new file mode 100644 index 0000000000000..d754991690678 --- /dev/null +++ b/tests/system/providers/microsoft/azure/example_azure_service_bus.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.models.baseoperator import chain +from airflow.providers.microsoft.azure.operators.asb import ( + AzureServiceBusCreateQueueOperator, + AzureServiceBusDeleteQueueOperator, + AzureServiceBusReceiveMessageOperator, + AzureServiceBusSendMessageOperator, +) + +EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6)) + +CLIENT_ID = os.getenv("CLIENT_ID", "") +QUEUE_NAME = "sb_mgmt_queue_test" +MESSAGE = "Test Message" +MESSAGE_LIST = [MESSAGE + " " + str(n) for n in range(0, 10)] + +with DAG( + dag_id="example_azure_service_bus_queue", + start_date=datetime(2021, 8, 13), + schedule_interval=None, + catchup=False, + default_args={ + "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT), + "azure_service_bus_conn_id": "azure_service_bus_default", + }, + tags=["example", "Azure service bus Queue"], +) as dag: + # [START howto_operator_create_service_bus_queue] + create_service_bus_queue = AzureServiceBusCreateQueueOperator( + task_id="create_service_bus_queue", + queue_name=QUEUE_NAME, + ) + # [END howto_operator_create_service_bus_queue] + + # [START howto_operator_send_message_to_service_bus_queue] + send_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_message_to_service_bus_queue", + message=MESSAGE, + queue_name=QUEUE_NAME, + batch=False, + ) + # [END howto_operator_send_message_to_service_bus_queue] + + # [START howto_operator_send_list_message_to_service_bus_queue] + send_list_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_list_message_to_service_bus_queue", + message=MESSAGE_LIST, + queue_name=QUEUE_NAME, + batch=False, + ) + # [END howto_operator_send_list_message_to_service_bus_queue] + + # [START howto_operator_send_batch_message_to_service_bus_queue] + send_batch_message_to_service_bus_queue = AzureServiceBusSendMessageOperator( + task_id="send_batch_message_to_service_bus_queue", + message=MESSAGE_LIST, + queue_name=QUEUE_NAME, + batch=True, + ) + # [END howto_operator_send_batch_message_to_service_bus_queue] + + # [START howto_operator_receive_message_service_bus_queue] + receive_message_service_bus_queue = AzureServiceBusReceiveMessageOperator( + task_id="receive_message_service_bus_queue", + queue_name=QUEUE_NAME, + max_message_count=20, + max_wait_time=5, + ) + # [END howto_operator_receive_message_service_bus_queue] + + # [START howto_operator_delete_service_bus_queue] + delete_service_bus_queue = AzureServiceBusDeleteQueueOperator( + task_id="delete_service_bus_queue", queue_name=QUEUE_NAME, trigger_rule="all_done" + ) + # [END howto_operator_delete_service_bus_queue] + + chain( + create_service_bus_queue, + send_message_to_service_bus_queue, + send_list_message_to_service_bus_queue, + send_batch_message_to_service_bus_queue, + receive_message_service_bus_queue, + delete_service_bus_queue, + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)