# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
# pylint: disable=cyclic-import
"""Components to communicate tasks to RabbitMQ."""
import asyncio
from collections.abc import Mapping
import logging
import traceback

from kiwipy import communications, Future
import pamqp.encode
import plumpy

from aiida.common.extendeddicts import AttributeDict

__all__ = ('RemoteException', 'CommunicationTimeout', 'DeliveryFailed', 'ProcessLauncher', 'BROKER_DEFAULTS')

# The following statement enables support for RabbitMQ 3.5 because without it, connections established by `aiormq` will
# fail because the interpretation of the types of integers passed in connection parameters has changed after that
# version. Once RabbitMQ 3.5 is no longer supported (it has been EOL since October 2016) this can be removed. This
# should also allow removing the direct dependency on `pamqp` entirely.
pamqp.encode.support_deprecated_rabbitmq()
LOGGER = logging.getLogger(__name__)

RemoteException = plumpy.RemoteException
DeliveryFailed = plumpy.DeliveryFailed
CommunicationTimeout = communications.TimeoutError  # pylint: disable=invalid-name

_LAUNCH_QUEUE = 'process.queue'
_MESSAGE_EXCHANGE = 'messages'
_TASK_EXCHANGE = 'tasks'

BROKER_DEFAULTS = AttributeDict({
    'protocol': 'amqp',
    'username': 'guest',
    'password': 'guest',
    'host': '127.0.0.1',
    'port': 5672,
    'virtual_host': '',
    'heartbeat': 600,
})


def get_rmq_url(protocol=None, username=None, password=None, host=None, port=None, virtual_host=None, **kwargs):
    """Return the URL to connect to RabbitMQ.

    .. note::

        The default of the ``host`` is set to ``127.0.0.1`` instead of ``localhost`` because on some computers
        localhost resolves first to IPv6 with address ::1 and if RMQ is not running on IPv6 one gets an annoying
        warning. For more info see: https://github.com/aiidateam/aiida-core/issues/1142

    :param protocol: the protocol to use, `amqp` or `amqps`.
    :param username: the username for authentication.
    :param password: the password for authentication.
    :param host: the hostname of the RabbitMQ server.
    :param port: the port of the RabbitMQ server.
    :param virtual_host: the virtual host to connect to.
    :param kwargs: remaining keyword arguments that will be encoded as query parameters.
    :returns: the connection URL string.
    """
    from urllib.parse import urlencode, urlunparse

    if 'heartbeat' not in kwargs:
        kwargs['heartbeat'] = BROKER_DEFAULTS.heartbeat

    scheme = protocol or BROKER_DEFAULTS.protocol
    netloc = '{username}:{password}@{host}:{port}'.format(
        username=username or BROKER_DEFAULTS.username,
        password=password or BROKER_DEFAULTS.password,
        host=host or BROKER_DEFAULTS.host,
        port=port or BROKER_DEFAULTS.port,
    )
    path = virtual_host or BROKER_DEFAULTS.virtual_host
    parameters = ''
    query = urlencode(kwargs)
    fragment = ''

    # The virtual host is optional but if it is specified it needs to start with a forward slash. If the virtual host
    # itself contains forward slashes, they need to be encoded.
    if path and not path.startswith('/'):
        path = f'/{path}'

    return urlunparse((scheme, netloc, path, parameters, query, fragment))
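
# Illustrative sketch (not part of the original module): with no arguments, ``get_rmq_url`` falls back to
# ``BROKER_DEFAULTS`` and appends the default heartbeat as a query parameter; any extra keyword arguments are
# URL-encoded into the query string. The hostname, virtual host and ``cafile`` below are made-up placeholder values:
#
#   get_rmq_url()
#   # 'amqp://guest:guest@127.0.0.1:5672?heartbeat=600'
#
#   get_rmq_url(protocol='amqps', host='rmq.example.com', virtual_host='aiida', cafile='/path/to/ca.pem')
#   # 'amqps://guest:guest@rmq.example.com:5672/aiida?cafile=%2Fpath%2Fto%2Fca.pem&heartbeat=600'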


def get_launch_queue_name(prefix=None):
    """Return the launch queue name with an optional prefix.

    :returns: launch queue name
    """
    if prefix is not None:
        return f'{prefix}.{_LAUNCH_QUEUE}'

    return _LAUNCH_QUEUE


def get_message_exchange_name(prefix):
    """Return the message exchange name for a given prefix.

    :returns: message exchange name
    """
    return f'{prefix}.{_MESSAGE_EXCHANGE}'


def get_task_exchange_name(prefix):
    """Return the task exchange name for a given prefix.

    :returns: task exchange name
    """
    return f'{prefix}.{_TASK_EXCHANGE}'
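
# Illustrative sketch (not part of the original module): the queue and exchange helpers simply namespace the fixed
# names defined above with a profile-specific prefix. With a hypothetical prefix such as 'aiida-abc123':
#
#   get_launch_queue_name()                     # 'process.queue'
#   get_launch_queue_name('aiida-abc123')       # 'aiida-abc123.process.queue'
#   get_message_exchange_name('aiida-abc123')   # 'aiida-abc123.messages'
#   get_task_exchange_name('aiida-abc123')      # 'aiida-abc123.tasks'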


def _store_inputs(inputs):
    """Try to store the values in the input dictionary.

    For nested dictionaries, the values are stored recursively.
    """
    for node in inputs.values():
        try:
            node.store()
        except AttributeError:
            if isinstance(node, Mapping):
                _store_inputs(node)
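
# Illustrative sketch (not part of the original module): ``_store_inputs`` stores every node in a possibly nested
# mapping of inputs, silently skipping values that have no ``store`` method while recursing into plain mappings.
# The dictionary below is a made-up example:
#
#   inputs = {'x': Int(1), 'metadata': {'options': {...}}}
#   _store_inputs(inputs)  # stores the `Int` node; descends into the nested dicts and ignores non-storable values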


class ProcessLauncher(plumpy.ProcessLauncher):
    """A subclass of `plumpy.ProcessLauncher` to launch a `Process`.

    It overrides the `_continue` method to make sure the node corresponding to the task can be loaded and that, if it
    is already marked as terminated, it is not continued but the future is reconstructed and returned.
    """

    @staticmethod
    def handle_continue_exception(node, exception, message):
        """Handle an exception raised in a `_continue` call.

        If the process state of the node has not yet been put to excepted, the exception was raised before the process
        instance could be reconstructed, for example when the process class could not be loaded, thereby circumventing
        the exception handling of the state machine. Raising this exception will then acknowledge the process task
        with RabbitMQ, leaving an uncleaned node in the `CREATED` state forever. Therefore, we have to perform the
        node cleaning manually.

        :param exception: the exception object
        :param message: string message to use for the log message
        """
        from aiida.engine import ProcessState

        if not node.is_excepted and not node.is_sealed:
            node.logger.exception(message)
            node.set_exception(''.join(traceback.format_exception(type(exception), exception, None)).rstrip())
            node.set_process_state(ProcessState.EXCEPTED)
            node.seal()

    async def _continue(self, communicator, pid, nowait, tag=None):
        """Continue the task.

        Note that the task may already have been completed, as indicated by the corresponding node, in which case it
        is not continued, but the corresponding future is reconstructed and returned. This scenario may occur when the
        Process was already completed by another worker that however failed to send the acknowledgment.

        :param communicator: the communicator that called this method
        :param pid: the pid of the process to continue
        :param nowait: if True, don't wait for the process to finish, just return the pid; otherwise wait and
            return the results
        :param tag: the tag of the checkpoint to continue from
        """
        from aiida.common import exceptions
        from aiida.engine.exceptions import PastException
        from aiida.orm import load_node, Data
        from aiida.orm.utils import serialize

        try:
            node = load_node(pk=pid)
        except (exceptions.MultipleObjectsError, exceptions.NotExistent) as exception:
            # In this case, the process node corresponding to the process id cannot be resolved uniquely or does not
            # exist. The latter is the most common case: someone deleted the node before the process was properly
            # terminated. Since the node is never coming back, the process will never be able to continue, so we
            # simply return instead of raising `TaskRejected`, because the latter would cause the task to be resent
            # and start to ping-pong between RabbitMQ and the daemon workers.
            LOGGER.exception('Cannot continue process<%d>', pid)
            return False

        if node.is_terminated:

            LOGGER.info(
                'not continuing process<%d> which is already terminated with state %s', pid, node.process_state
            )

            future = Future()

            if node.is_finished:
                future.set_result({entry.link_label: entry.node for entry in node.get_outgoing(node_class=Data)})
            elif node.is_excepted:
                future.set_exception(PastException(node.exception))
            elif node.is_killed:
                future.set_exception(plumpy.KilledError())

            return future.result()

        try:
            result = await super()._continue(communicator, pid, nowait, tag)
        except ImportError as exception:
            message = 'the class of the process could not be imported.'
            self.handle_continue_exception(node, exception, message)
            raise
        except asyncio.CancelledError:  # pylint: disable=try-except-raise
            # Note: this is only required in Python <= 3.7, where asyncio.CancelledError inherits from Exception
            raise
        except Exception as exception:
            message = 'failed to recreate the process instance in order to continue it.'
            self.handle_continue_exception(node, exception, message)
            raise

        # Ensure that the result is serialized such that the communication thread won't have to do database operations
        try:
            serialized = serialize.serialize(result)
        except Exception:
            LOGGER.exception('failed to serialize the result for process<%d>', pid)
            raise

        return serialized
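
# Illustrative sketch (not part of the original module; the communicator setup below is an assumption, not the exact
# wiring done by aiida-core): daemon workers register a ``ProcessLauncher`` instance as a task subscriber on a kiwipy
# communicator, so that ``_continue`` is invoked for every process task pulled from the launch queue. Roughly:
#
#   from kiwipy.rmq import RmqThreadCommunicator
#
#   communicator = RmqThreadCommunicator.connect(connection_params={'url': get_rmq_url()})  # assumed call signature
#   launcher = ProcessLauncher(persister=..., loader=...)  # persister/loader wiring deliberately elided
#   communicator.add_task_subscriber(launcher)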