diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 46dd7d410..7e63933e1 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -14,11 +14,13 @@ """Firebase Cloud Messaging module.""" +import concurrent.futures import json +import warnings +import requests from googleapiclient import http from googleapiclient import _auth -import requests import firebase_admin from firebase_admin import _http_client @@ -26,6 +28,7 @@ from firebase_admin import _messaging_utils from firebase_admin import _gapic_utils from firebase_admin import _utils +from firebase_admin import exceptions _MESSAGING_ATTRIBUTE = '_messaging' @@ -115,6 +118,57 @@ def send(message, dry_run=False, app=None): """ return _get_messaging_service(app).send(message, dry_run) +def send_each(messages, dry_run=False, app=None): + """Sends each message in the given list via Firebase Cloud Messaging. + + If the ``dry_run`` mode is enabled, the message will not be actually delivered to the + recipients. Instead FCM performs all the usual validations, and emulates the send operation. + + Args: + messages: A list of ``messaging.Message`` instances. + dry_run: A boolean indicating whether to run the operation in dry run mode (optional). + app: An App instance (optional). + + Returns: + BatchResponse: A ``messaging.BatchResponse`` instance. + + Raises: + FirebaseError: If an error occurs while sending the message to the FCM service. + ValueError: If the input arguments are invalid. + """ + return _get_messaging_service(app).send_each(messages, dry_run) + +def send_each_for_multicast(multicast_message, dry_run=False, app=None): + """Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM). + + If the ``dry_run`` mode is enabled, the message will not be actually delivered to the + recipients. Instead FCM performs all the usual validations, and emulates the send operation. + + Args: + multicast_message: An instance of ``messaging.MulticastMessage``. + dry_run: A boolean indicating whether to run the operation in dry run mode (optional). + app: An App instance (optional). + + Returns: + BatchResponse: A ``messaging.BatchResponse`` instance. + + Raises: + FirebaseError: If an error occurs while sending the message to the FCM service. + ValueError: If the input arguments are invalid. + """ + if not isinstance(multicast_message, MulticastMessage): + raise ValueError('Message must be an instance of messaging.MulticastMessage class.') + messages = [Message( + data=multicast_message.data, + notification=multicast_message.notification, + android=multicast_message.android, + webpush=multicast_message.webpush, + apns=multicast_message.apns, + fcm_options=multicast_message.fcm_options, + token=token + ) for token in multicast_message.tokens] + return _get_messaging_service(app).send_each(messages, dry_run) + def send_all(messages, dry_run=False, app=None): """Sends the given list of messages via Firebase Cloud Messaging as a single batch. @@ -132,7 +186,10 @@ def send_all(messages, dry_run=False, app=None): Raises: FirebaseError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. + + send_all() is deprecated. Use send_each() instead. """ + warnings.warn('send_all() is deprecated. Use send_each() instead.', DeprecationWarning) return _get_messaging_service(app).send_all(messages, dry_run) def send_multicast(multicast_message, dry_run=False, app=None): @@ -152,7 +209,11 @@ def send_multicast(multicast_message, dry_run=False, app=None): Raises: FirebaseError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. + + send_multicast() is deprecated. Use send_each_for_multicast() instead. """ + warnings.warn('send_multicast() is deprecated. Use send_each_for_multicast() instead.', + DeprecationWarning) if not isinstance(multicast_message, MulticastMessage): raise ValueError('Message must be an instance of messaging.MulticastMessage class.') messages = [Message( @@ -356,6 +417,35 @@ def send(self, message, dry_run=False): else: return resp['name'] + def send_each(self, messages, dry_run=False): + """Sends the given messages to FCM via the FCM v1 API.""" + if not isinstance(messages, list): + raise ValueError('messages must be a list of messaging.Message instances.') + if len(messages) > 500: + raise ValueError('messages must not contain more than 500 elements.') + + def send_data(data): + try: + resp = self._client.body( + 'post', + url=self._fcm_url, + headers=self._fcm_headers, + json=data) + except requests.exceptions.RequestException as exception: + return SendResponse(resp=None, exception=self._handle_fcm_error(exception)) + else: + return SendResponse(resp, exception=None) + + message_data = [self._message_data(message, dry_run) for message in messages] + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=len(message_data)) as executor: + responses = [resp for resp in executor.map(send_data, message_data)] + return BatchResponse(responses) + except Exception as error: + raise exceptions.UnknownError( + message='Unknown error while making remote service calls: {0}'.format(error), + cause=error) + def send_all(self, messages, dry_run=False): """Sends the given messages to FCM via the batch API.""" if not isinstance(messages, list): diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 3d8740cc1..71bb13eed 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1812,6 +1812,16 @@ def setup_class(cls): def teardown_class(cls): testutils.cleanup_apps() + def _instrument_messaging_service(self, response_dict, app=None): + if not app: + app = firebase_admin.get_app() + fcm_service = messaging._get_messaging_service(app) + recorder = [] + fcm_service._client.session.mount( + 'https://fcm.googleapis.com', + testutils.MockRequestBasedMultiRequestAdapter(response_dict, recorder)) + return fcm_service, recorder + def _instrument_batch_messaging_service(self, app=None, status=200, payload='', exc=None): def build_mock_transport(_): if exc: @@ -1844,6 +1854,261 @@ def _batch_payload(self, payloads): return payload +class TestSendEach(TestBatch): + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + messaging.send_each([messaging.Message(topic='foo')], app=app) + testutils.run_without_project_id(evaluate) + + @pytest.mark.parametrize('msg', NON_LIST_ARGS) + def test_invalid_send_each(self, msg): + with pytest.raises(ValueError) as excinfo: + messaging.send_each(msg) + if isinstance(msg, list): + expected = 'Message must be an instance of messaging.Message class.' + assert str(excinfo.value) == expected + else: + expected = 'messages must be a list of messaging.Message instances.' + assert str(excinfo.value) == expected + + def test_invalid_over_500(self): + msg = messaging.Message(topic='foo') + with pytest.raises(ValueError) as excinfo: + messaging.send_each([msg for _ in range(0, 501)]) + expected = 'messages must not contain more than 500 elements.' + assert str(excinfo.value) == expected + + def test_send_each(self): + payload1 = json.dumps({'name': 'message-id1'}) + payload2 = json.dumps({'name': 'message-id2'}) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, payload1], 'foo2': [200, payload2]}) + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + batch_response = messaging.send_each([msg1, msg2], dry_run=True) + assert batch_response.success_count == 2 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 2 + assert [r.message_id for r in batch_response.responses] == ['message-id1', 'message-id2'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_each_detailed_error(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error' + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + batch_response = messaging.send_each([msg1, msg2]) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + exception = error_response.exception + assert isinstance(exception, exceptions.InvalidArgumentError) + check_exception(exception, 'test error', status) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_each_canonical_error_code(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'NOT_FOUND', + 'message': 'test error' + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + batch_response = messaging.send_each([msg1, msg2]) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + exception = error_response.exception + assert isinstance(exception, exceptions.NotFoundError) + check_exception(exception, 'test error', status) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) + def test_send_each_fcm_error_code(self, status, fcm_error_code, exc_type): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': fcm_error_code, + }, + ], + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + batch_response = messaging.send_each([msg1, msg2]) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + exception = error_response.exception + assert isinstance(exception, exc_type) + check_exception(exception, 'test error', status) + + +class TestSendEachForMulticast(TestBatch): + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + messaging.send_all([messaging.Message(topic='foo')], app=app) + testutils.run_without_project_id(evaluate) + + @pytest.mark.parametrize('msg', NON_LIST_ARGS) + def test_invalid_send_each_for_multicast(self, msg): + with pytest.raises(ValueError) as excinfo: + messaging.send_multicast(msg) + expected = 'Message must be an instance of messaging.MulticastMessage class.' + assert str(excinfo.value) == expected + + def test_send_each_for_multicast(self): + payload1 = json.dumps({'name': 'message-id1'}) + payload2 = json.dumps({'name': 'message-id2'}) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, payload1], 'foo2': [200, payload2]}) + msg = messaging.MulticastMessage(tokens=['foo1', 'foo2']) + batch_response = messaging.send_each_for_multicast(msg, dry_run=True) + assert batch_response.success_count == 2 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 2 + assert [r.message_id for r in batch_response.responses] == ['message-id1', 'message-id2'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_each_for_multicast_detailed_error(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error' + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg = messaging.MulticastMessage(tokens=['foo1', 'foo2']) + batch_response = messaging.send_each_for_multicast(msg) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert isinstance(exception, exceptions.InvalidArgumentError) + check_exception(exception, 'test error', status) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_each_for_multicast_canonical_error_code(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'NOT_FOUND', + 'message': 'test error' + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg = messaging.MulticastMessage(tokens=['foo1', 'foo2']) + batch_response = messaging.send_each_for_multicast(msg) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert isinstance(exception, exceptions.NotFoundError) + check_exception(exception, 'test error', status) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_each_for_multicast_fcm_error_code(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'UNREGISTERED', + }, + ], + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg = messaging.MulticastMessage(tokens=['foo1', 'foo2']) + batch_response = messaging.send_each_for_multicast(msg) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert isinstance(exception, messaging.UnregisteredError) + check_exception(exception, 'test error', status) + + class TestSendAll(TestBatch): def test_no_project_id(self): diff --git a/tests/testutils.py b/tests/testutils.py index 92755107c..e52b90d1a 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -171,3 +171,33 @@ def status(self): @property def data(self): return self._responses[0] + +class MockRequestBasedMultiRequestAdapter(adapters.HTTPAdapter): + """A mock HTTP adapter that supports multiple responses for the Python requests module. + The response for each incoming request should be specified in response_dict during + initialization. Each incoming request should contain an identifier in the its body.""" + def __init__(self, response_dict, recorder): + """Constructs a MockRequestBasedMultiRequestAdapter. + + Each incoming request consumes the response and status mapped to it. If no response + is specified for the request, the response will be 404 with an empty body. + """ + adapters.HTTPAdapter.__init__(self) + self._current_response = 0 + self._response_dict = dict(response_dict) + self._recorder = recorder + + def send(self, request, **kwargs): # pylint: disable=arguments-differ + request._extra_kwargs = kwargs + self._recorder.append(request) + resp = models.Response() + resp.url = request.url + resp.status_code = 404 # Not found. + resp.raw = None + for req_id, pair in self._response_dict.items(): + if req_id in str(request.body): + status, response = pair + resp.status_code = status + resp.raw = io.BytesIO(response.encode()) + break + return resp