diff --git a/tests/providers/microsoft/azure/hooks/test_asb_message.py b/tests/providers/microsoft/azure/hooks/test_asb_message.py index bbe58056f9817..bc1d93c366009 100644 --- a/tests/providers/microsoft/azure/hooks/test_asb_message.py +++ b/tests/providers/microsoft/azure/hooks/test_asb_message.py @@ -121,23 +121,28 @@ def test_send_message_exception(self, mock_sb_client): with pytest.raises(AirflowException): hook.send_message(queue_name=None, messages="", batch_message_flag=False) + @mock.patch('azure.servicebus.ServiceBusMessage') @mock.patch('airflow.providers.microsoft.azure.hooks.asb_message.ServiceBusMessageHook.get_conn') - def test_receive_message(self, mock_sb_client): + def test_receive_message(self, mock_sb_client, mock_service_bus_message): """ Test `receive_message` hook function and assert the function with mock value, mock the azure service bus `receive_messages` function """ hook = ServiceBusMessageHook(azure_service_bus_conn_id=self.conn_id) + mock_sb_client.return_value.get_queue_receiver.return_value.receive_messages.return_value = [ + mock_service_bus_message + ] hook.receive_message(self.queue_name) expected_calls = [ - mock.call(), - mock.call().__enter__(), - mock.call().get_queue_receiver(queue_name='test_queue'), - mock.call().get_queue_receiver().__enter__(), - mock.call().get_queue_receiver().receive_messages(max_message_count=1, max_wait_time=None), - mock.call().get_queue_receiver().receive_messages().__iter__(), - mock.call().get_queue_receiver().__exit__(None, None, None), - mock.call().__exit__(None, None, None), + mock.call() + .__enter__() + .get_queue_receiver(self.queue_name) + .__enter__() + .receive_messages(max_message_count=30, max_wait_time=5) + .get_queue_receiver(self.queue_name) + .__exit__() + .mock_call() + .__exit__ ] mock_sb_client.assert_has_calls(expected_calls) diff --git a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py index d307234462732..764d4d28d7920 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_service_queue.py +++ b/tests/providers/microsoft/azure/operators/test_azure_service_queue.py @@ -221,13 +221,14 @@ def test_receive_message_queue(self, mock_get_conn): ) asb_receive_queue_operator.execute(None) expected_calls = [ - mock.call(), - mock.call().__enter__(), - mock.call().get_queue_receiver(queue_name='test_queue'), - mock.call().get_queue_receiver().__enter__(), - mock.call().get_queue_receiver().receive_messages(max_message_count=10, max_wait_time=5), - mock.call().get_queue_receiver().receive_messages().__iter__(), - mock.call().get_queue_receiver().__exit__(None, None, None), - mock.call().__exit__(None, None, None), + mock.call() + .__enter__() + .get_queue_receiver(QUEUE_NAME) + .__enter__() + .receive_messages(max_message_count=10, max_wait_time=5) + .get_queue_receiver(QUEUE_NAME) + .__exit__() + .mock_call() + .__exit__ ] mock_get_conn.assert_has_calls(expected_calls)