diff --git a/azure/functions/__init__.py b/azure/functions/__init__.py index 86565e37..eee6c1c9 100644 --- a/azure/functions/__init__.py +++ b/azure/functions/__init__.py @@ -23,6 +23,7 @@ from ._queue import QueueMessage from ._servicebus import ServiceBusMessage from ._sql import SqlRow, SqlRowList +from ._mysql import MySqlRow, MySqlRowList # Import binding implementations to register them from . import blob # NoQA @@ -37,6 +38,7 @@ from . import durable_functions # NoQA from . import sql # NoQA from . import warmup # NoQA +from . import mysql # NoQA __all__ = ( @@ -67,6 +69,8 @@ 'SqlRowList', 'TimerRequest', 'WarmUpContext', + 'MySqlRow', + 'MySqlRowList', # Middlewares 'WsgiMiddleware', diff --git a/azure/functions/_mysql.py b/azure/functions/_mysql.py new file mode 100644 index 00000000..9c7515d9 --- /dev/null +++ b/azure/functions/_mysql.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import abc +import collections +import json + + +class BaseMySqlRow(abc.ABC): + + @classmethod + @abc.abstractmethod + def from_json(cls, json_data: str) -> 'BaseMySqlRow': + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def from_dict(cls, dct: dict) -> 'BaseMySqlRow': + raise NotImplementedError + + @abc.abstractmethod + def __getitem__(self, key): + raise NotImplementedError + + @abc.abstractmethod + def __setitem__(self, key, value): + raise NotImplementedError + + @abc.abstractmethod + def to_json(self) -> str: + raise NotImplementedError + + +class BaseMySqlRowList(abc.ABC): + pass + + +class MySqlRow(BaseMySqlRow, collections.UserDict): + """A MySql Row. + + MySqlRow objects are ''UserDict'' subclasses and behave like dicts. + """ + + @classmethod + def from_json(cls, json_data: str) -> 'BaseMySqlRow': + """Create a MySqlRow from a JSON string.""" + return cls.from_dict(json.loads(json_data)) + + @classmethod + def from_dict(cls, dct: dict) -> 'BaseMySqlRow': + """Create a MySqlRow from a dict object""" + return cls({k: v for k, v in dct.items()}) + + def to_json(self) -> str: + """Return the JSON representation of the MySqlRow""" + return json.dumps(dict(self)) + + def __getitem__(self, key): + return collections.UserDict.__getitem__(self, key) + + def __setitem__(self, key, value): + return collections.UserDict.__setitem__(self, key, value) + + def __repr__(self) -> str: + return ( + f'' + ) + + +class MySqlRowList(BaseMySqlRowList, collections.UserList): + "A ''UserList'' subclass containing a list of :class:'~MySqlRow' objects" + pass diff --git a/azure/functions/decorators/blob.py b/azure/functions/decorators/blob.py index 1a2d4122..bd2861fa 100644 --- a/azure/functions/decorators/blob.py +++ b/azure/functions/decorators/blob.py @@ -17,7 +17,7 @@ def __init__(self, **kwargs): self.path = path self.connection = connection - self.source = source + self.source = source.value if source else None super().__init__(name=name, data_type=data_type) @staticmethod diff --git a/azure/functions/mysql.py b/azure/functions/mysql.py new file mode 100644 index 00000000..06a04a56 --- /dev/null +++ b/azure/functions/mysql.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import collections.abc +import json +import typing + +from azure.functions import _mysql as mysql + +from . import meta + + +class MySqlConverter(meta.InConverter, meta.OutConverter, + binding='mysql'): + + @classmethod + def check_input_type_annotation(cls, pytype: type) -> bool: + return issubclass(pytype, mysql.BaseMySqlRowList) + + @classmethod + def check_output_type_annotation(cls, pytype: type) -> bool: + return issubclass(pytype, (mysql.BaseMySqlRowList, mysql.BaseMySqlRow)) + + @classmethod + def decode(cls, + data: meta.Datum, + *, + trigger_metadata) -> typing.Optional[mysql.MySqlRowList]: + if data is None or data.type is None: + return None + + data_type = data.type + + if data_type in ['string', 'json']: + body = data.value + + elif data_type == 'bytes': + body = data.value.decode('utf-8') + + else: + raise NotImplementedError( + f'Unsupported payload type: {data_type}') + + rows = json.loads(body) + if not isinstance(rows, list): + rows = [rows] + + return mysql.MySqlRowList( + (None if row is None else mysql.MySqlRow.from_dict(row)) + for row in rows) + + @classmethod + def encode(cls, obj: typing.Any, *, + expected_type: typing.Optional[type]) -> meta.Datum: + if isinstance(obj, mysql.MySqlRow): + data = mysql.MySqlRowList([obj]) + + elif isinstance(obj, mysql.MySqlRowList): + data = obj + + elif isinstance(obj, collections.abc.Iterable): + data = mysql.MySqlRowList() + + for row in obj: + if not isinstance(row, mysql.MySqlRow): + raise NotImplementedError( + f'Unsupported list type: {type(obj)}, \ + lists must contain MySqlRow objects') + else: + data.append(row) + + else: + raise NotImplementedError(f'Unsupported type: {type(obj)}') + + return meta.Datum( + type='json', + value=json.dumps([dict(d) for d in data]) + ) diff --git a/tests/decorators/test_blob.py b/tests/decorators/test_blob.py index f8712b9f..43926591 100644 --- a/tests/decorators/test_blob.py +++ b/tests/decorators/test_blob.py @@ -42,7 +42,7 @@ def test_blob_trigger_creation_with_default_specified_source(self): "name": "req", "dataType": DataType.UNDEFINED, "path": "dummy_path", - 'source': BlobSource.LOGS_AND_CONTAINER_SCAN, + 'source': 'LogsAndContainerScan', "connection": "dummy_connection" }) @@ -62,7 +62,7 @@ def test_blob_trigger_creation_with_source_as_string(self): "name": "req", "dataType": DataType.UNDEFINED, "path": "dummy_path", - 'source': BlobSource.EVENT_GRID, + 'source': 'EventGrid', "connection": "dummy_connection" }) @@ -82,7 +82,7 @@ def test_blob_trigger_creation_with_source_as_enum(self): "name": "req", "dataType": DataType.UNDEFINED, "path": "dummy_path", - 'source': BlobSource.EVENT_GRID, + 'source': 'EventGrid', "connection": "dummy_connection" }) diff --git a/tests/decorators/test_decorators.py b/tests/decorators/test_decorators.py index 82973ba3..c801fd8d 100644 --- a/tests/decorators/test_decorators.py +++ b/tests/decorators/test_decorators.py @@ -1628,7 +1628,7 @@ def test_blob_input_binding(): "type": BLOB_TRIGGER, "name": "req", "path": "dummy_path", - "source": BlobSource.EVENT_GRID, + "source": "EventGrid", "connection": "dummy_conn" }) diff --git a/tests/test_mysql.py b/tests/test_mysql.py new file mode 100644 index 00000000..514c066a --- /dev/null +++ b/tests/test_mysql.py @@ -0,0 +1,293 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import json +import unittest + +import azure.functions as func +import azure.functions.mysql as mysql +from azure.functions.meta import Datum + + +class TestMySql(unittest.TestCase): + def test_mysql_decode_none(self): + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=None, trigger_metadata=None) + self.assertIsNone(result) + + def test_mysql_decode_string(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": "test" + } + """, "string") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result, + 'MySqlRowList should be non-None') + self.assertEqual(len(result), + 1, + 'MySqlRowList should have exactly 1 item') + self.assertEqual(result[0]['id'], + '1', + 'MySqlRow item should have id 1') + self.assertEqual(result[0]['name'], + 'test', + 'MySqlRow item should have name test') + + def test_mysql_decode_bytes(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": "test" + } + """.encode(), "bytes") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result, + 'MySqlRowList should be non-None') + self.assertEqual(len(result), + 1, + 'MySqlRowList should have exactly 1 item') + self.assertEqual(result[0]['id'], + '1', + 'MySqlRow item should have id 1') + self.assertEqual(result[0]['name'], + 'test', + 'MySqlRow item should have name test') + + def test_mysql_decode_json(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": "test" + } + """, "json") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result, + 'MySqlRowList should be non-None') + self.assertEqual(len(result), + 1, + 'MySqlRowList should have exactly 1 item') + self.assertEqual(result[0]['id'], + '1', + 'MySqlRow item should have id 1') + self.assertEqual(result[0]['name'], + 'test', + 'MySqlRow item should have name test') + + def test_mysql_decode_json_name_is_null(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": null + } + """, "json") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result, + 'MySqlRowList itself should be non-None') + self.assertEqual(len(result), + 1, + 'MySqlRowList should have exactly 1 item') + self.assertEqual(result[0]['name'], + None, + 'Item in MySqlRowList should be None') + + def test_mysql_decode_json_multiple_entries(self): + datum: Datum = Datum(""" + [ + { + "id": "1", + "name": "test1" + }, + { + "id": "2", + "name": "test2" + } + ] + """, "json") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result) + self.assertEqual(len(result), + 2, + 'MySqlRowList should have exactly 2 items') + self.assertEqual(result[0]['id'], + '1', + 'First MySqlRowList item should have id 1') + self.assertEqual(result[0]['name'], + 'test1', + 'First MySqlRowList item should have name test1') + self.assertEqual(result[1]['id'], + '2', + 'First MySqlRowList item should have id 2') + self.assertEqual(result[1]['name'], + 'test2', + 'Second MySqlRowList item should have name test2') + + def test_mysql_decode_json_multiple_nulls(self): + datum: Datum = Datum("[null]", "json") + result: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + self.assertIsNotNone(result) + self.assertEqual(len(result), + 1, + 'MySqlRowList should have exactly 1 item') + self.assertEqual(result[0], + None, + 'MySqlRow item should be None') + + def test_mysql_encode_mysqlrow(self): + mysqlRow = func.MySqlRow.from_json(""" + { + "id": "1", + "name": "test" + } + """) + datum = mysql.MySqlConverter.encode(obj=mysqlRow, expected_type=None) + self.assertEqual(datum.type, + 'json', + 'Datum type should be JSON') + self.assertEqual(len(datum.python_value), + 1, + 'Encoded value should be list of length 1') + self.assertEqual(datum.python_value[0]['id'], + '1', + 'id should be 1') + self.assertEqual(datum.python_value[0]['name'], + 'test', + 'name should be test') + + def test_mysql_encode_mysqlrowlist(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": "test" + } + """, "json") + mysqlRowList: func.MySqlRowList = mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + datum = mysql.MySqlConverter.encode( + obj=mysqlRowList, expected_type=None) + self.assertEqual(datum.type, + 'json', + 'Datum type should be JSON') + self.assertEqual(len(datum.python_value), + 1, + 'Encoded value should be list of length 1') + self.assertEqual(datum.python_value[0]['id'], + '1', + 'id should be 1') + self.assertEqual(datum.python_value[0]['name'], + 'test', + 'name should be test') + + def test_mysql_encode_list_of_mysqlrows(self): + mysqlRows = [ + func.MySqlRow.from_json(""" + { + "id": "1", + "name": "test" + } + """), + func.MySqlRow.from_json(""" + { + "id": "2", + "name": "test2" + } + """) + ] + datum = mysql.MySqlConverter.encode(obj=mysqlRows, expected_type=None) + self.assertEqual(datum.type, + 'json', + 'Datum type should be JSON') + self.assertEqual(len(datum.python_value), + 2, + 'Encoded value should be list of length 2') + self.assertEqual(datum.python_value[0]['id'], + '1', + 'id should be 1') + self.assertEqual(datum.python_value[0]['name'], + 'test', + 'name should be test') + self.assertEqual(datum.python_value[1]['id'], + '2', + 'id should be 2') + self.assertEqual(datum.python_value[1]['name'], + 'test2', + 'name should be test2') + + def test_mysql_encode_list_of_str_raises(self): + strList = [ + """ + { + "id": "1", + "name": "test" + } + """ + ] + self.assertRaises(NotImplementedError, + mysql.MySqlConverter.encode, + obj=strList, + expected_type=None) + + def test_mysql_encode_list_of_mysqlrowlist_raises(self): + datum: Datum = Datum(""" + { + "id": "1", + "name": "test" + } + """, "json") + mysqlRowListList = [ + mysql.MySqlConverter.decode( + data=datum, trigger_metadata=None) + ] + self.assertRaises(NotImplementedError, + mysql.MySqlConverter.encode, + obj=mysqlRowListList, + expected_type=None) + + def test_mysql_input_type(self): + check_input_type = mysql.MySqlConverter.check_input_type_annotation + self.assertTrue(check_input_type(func.MySqlRowList), + 'MySqlRowList should be accepted') + self.assertFalse(check_input_type(func.MySqlRow), + 'MySqlRow should not be accepted') + self.assertFalse(check_input_type(str), + 'str should not be accepted') + + def test_mysql_output_type(self): + check_output_type = mysql.MySqlConverter.check_output_type_annotation + self.assertTrue(check_output_type(func.MySqlRowList), + 'MySqlRowList should be accepted') + self.assertTrue(check_output_type(func.MySqlRow), + 'MySqlRow should be accepted') + self.assertFalse(check_output_type(str), + 'str should not be accepted') + + def test_mysqlrow_json(self): + # Parse MySqlRow from JSON + mysqlRow = func.MySqlRow.from_json(""" + { + "id": "1", + "name": "test" + } + """) + self.assertEqual(mysqlRow['id'], + '1', + 'Parsed MySqlRow id should be 1') + self.assertEqual(mysqlRow['name'], + 'test', + 'Parsed MySqlRow name should be test') + + # Parse JSON from MySqlRow + mysqlRowJson = json.loads(func.MySqlRow.to_json(mysqlRow)) + self.assertEqual(mysqlRowJson['id'], + '1', + 'Parsed JSON id should be 1') + self.assertEqual(mysqlRowJson['name'], + 'test', + 'Parsed JSON name should be test')