From 922fcd105502a07c7e550dcabceac9284de29307 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 17 Aug 2023 11:50:47 -0700 Subject: [PATCH] Add support for tuple-format custom codecs on composite types (#1061) It is now possible to `set_type_codec('mycomposite', ... format='tuple')`, which is useful for types that are represented by a composite type in Postgres, but are an integral type in Python, e.g. `complex`. Fixes: #1060 --- asyncpg/connection.py | 34 ++++++++++-- asyncpg/introspection.py | 4 ++ asyncpg/protocol/codecs/base.pxd | 3 ++ asyncpg/protocol/codecs/base.pyx | 90 ++++++++++++++++++++++---------- asyncpg/protocol/settings.pxd | 2 +- asyncpg/protocol/settings.pyx | 6 ++- docs/usage.rst | 43 ++++++++++++++- tests/test_codecs.py | 66 ++++++++++++++++------- 8 files changed, 192 insertions(+), 56 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 2d689512..45cf99b1 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1160,6 +1160,9 @@ async def set_type_codec(self, typename, *, | ``time with | (``microseconds``, | | time zone`` | ``time zone offset in seconds``) | +-----------------+---------------------------------------------+ + | any composite | Composite value elements | + | type | | + +-----------------+---------------------------------------------+ :param encoder: Callable accepting a Python object as a single argument and @@ -1214,6 +1217,10 @@ async def set_type_codec(self, typename, *, The ``binary`` keyword argument was removed in favor of ``format``. + .. versionchanged:: 0.29.0 + Custom codecs for composite types are now supported with + ``format='tuple'``. + .. note:: It is recommended to use the ``'binary'`` or ``'tuple'`` *format* @@ -1224,11 +1231,28 @@ async def set_type_codec(self, typename, *, codecs. """ self._check_open() + settings = self._protocol.get_settings() typeinfo = await self._introspect_type(typename, schema) - if not introspection.is_scalar_type(typeinfo): + full_typeinfos = [] + if introspection.is_scalar_type(typeinfo): + kind = 'scalar' + elif introspection.is_composite_type(typeinfo): + if format != 'tuple': + raise exceptions.UnsupportedClientFeatureError( + 'only tuple-format codecs can be used on composite types', + hint="Use `set_type_codec(..., format='tuple')` and " + "pass/interpret data as a Python tuple. See an " + "example at https://magicstack.github.io/asyncpg/" + "current/usage.html#example-decoding-complex-types", + ) + kind = 'composite' + full_typeinfos, _ = await self._introspect_types( + (typeinfo['oid'],), 10) + else: raise exceptions.InterfaceError( - 'cannot use custom codec on non-scalar type {}.{}'.format( - schema, typename)) + f'cannot use custom codec on type {schema}.{typename}: ' + f'it is neither a scalar type nor a composite type' + ) if introspection.is_domain_type(typeinfo): raise exceptions.UnsupportedClientFeatureError( 'custom codecs on domain types are not supported', @@ -1240,8 +1264,8 @@ async def set_type_codec(self, typename, *, ) oid = typeinfo['oid'] - self._protocol.get_settings().add_python_codec( - oid, typename, schema, 'scalar', + settings.add_python_codec( + oid, typename, schema, full_typeinfos, kind, encoder, decoder, format) # Statement cache is no longer valid due to codec changes. diff --git a/asyncpg/introspection.py b/asyncpg/introspection.py index d62f39a0..6c2caf03 100644 --- a/asyncpg/introspection.py +++ b/asyncpg/introspection.py @@ -286,3 +286,7 @@ def is_scalar_type(typeinfo) -> bool: def is_domain_type(typeinfo) -> bool: return typeinfo['kind'] == b'd' + + +def is_composite_type(typeinfo) -> bool: + return typeinfo['kind'] == b'c' diff --git a/asyncpg/protocol/codecs/base.pxd b/asyncpg/protocol/codecs/base.pxd index 16928b88..1cfed833 100644 --- a/asyncpg/protocol/codecs/base.pxd +++ b/asyncpg/protocol/codecs/base.pxd @@ -57,6 +57,7 @@ cdef class Codec: encode_func c_encoder decode_func c_decoder + Codec base_codec object py_encoder object py_decoder @@ -79,6 +80,7 @@ cdef class Codec: CodecType type, ServerDataFormat format, ClientExchangeFormat xformat, encode_func c_encoder, decode_func c_decoder, + Codec base_codec, object py_encoder, object py_decoder, Codec element_codec, tuple element_type_oids, object element_names, list element_codecs, @@ -169,6 +171,7 @@ cdef class Codec: object decoder, encode_func c_encoder, decode_func c_decoder, + Codec base_codec, ServerDataFormat format, ClientExchangeFormat xformat) diff --git a/asyncpg/protocol/codecs/base.pyx b/asyncpg/protocol/codecs/base.pyx index 273b27aa..c269e374 100644 --- a/asyncpg/protocol/codecs/base.pyx +++ b/asyncpg/protocol/codecs/base.pyx @@ -23,14 +23,25 @@ cdef class Codec: self.oid = oid self.type = CODEC_UNDEFINED - cdef init(self, str name, str schema, str kind, - CodecType type, ServerDataFormat format, - ClientExchangeFormat xformat, - encode_func c_encoder, decode_func c_decoder, - object py_encoder, object py_decoder, - Codec element_codec, tuple element_type_oids, - object element_names, list element_codecs, - Py_UCS4 element_delimiter): + cdef init( + self, + str name, + str schema, + str kind, + CodecType type, + ServerDataFormat format, + ClientExchangeFormat xformat, + encode_func c_encoder, + decode_func c_decoder, + Codec base_codec, + object py_encoder, + object py_decoder, + Codec element_codec, + tuple element_type_oids, + object element_names, + list element_codecs, + Py_UCS4 element_delimiter, + ): self.name = name self.schema = schema @@ -40,6 +51,7 @@ cdef class Codec: self.xformat = xformat self.c_encoder = c_encoder self.c_decoder = c_decoder + self.base_codec = base_codec self.py_encoder = py_encoder self.py_decoder = py_decoder self.element_codec = element_codec @@ -48,6 +60,12 @@ cdef class Codec: self.element_delimiter = element_delimiter self.element_names = element_names + if base_codec is not None: + if c_encoder != NULL or c_decoder != NULL: + raise exceptions.InternalClientError( + 'base_codec is mutually exclusive with c_encoder/c_decoder' + ) + if element_names is not None: self.record_desc = record.ApgRecordDesc_New( element_names, tuple(element_names)) @@ -98,7 +116,7 @@ cdef class Codec: codec = Codec(self.oid) codec.init(self.name, self.schema, self.kind, self.type, self.format, self.xformat, - self.c_encoder, self.c_decoder, + self.c_encoder, self.c_decoder, self.base_codec, self.py_encoder, self.py_decoder, self.element_codec, self.element_type_oids, self.element_names, @@ -196,7 +214,10 @@ cdef class Codec: raise exceptions.InternalClientError( 'unexpected data format: {}'.format(self.format)) elif self.xformat == PG_XFORMAT_TUPLE: - self.c_encoder(settings, buf, data) + if self.base_codec is not None: + self.base_codec.encode(settings, buf, data) + else: + self.c_encoder(settings, buf, data) else: raise exceptions.InternalClientError( 'unexpected exchange format: {}'.format(self.xformat)) @@ -295,7 +316,10 @@ cdef class Codec: raise exceptions.InternalClientError( 'unexpected data format: {}'.format(self.format)) elif self.xformat == PG_XFORMAT_TUPLE: - data = self.c_decoder(settings, buf) + if self.base_codec is not None: + data = self.base_codec.decode(settings, buf) + else: + data = self.c_decoder(settings, buf) else: raise exceptions.InternalClientError( 'unexpected exchange format: {}'.format(self.xformat)) @@ -367,8 +391,8 @@ cdef class Codec: cdef Codec codec codec = Codec(oid) codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format, - PG_XFORMAT_OBJECT, NULL, NULL, None, None, element_codec, - None, None, None, element_delimiter) + PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, + element_codec, None, None, None, element_delimiter) return codec @staticmethod @@ -379,8 +403,8 @@ cdef class Codec: cdef Codec codec codec = Codec(oid) codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format, - PG_XFORMAT_OBJECT, NULL, NULL, None, None, element_codec, - None, None, None, 0) + PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, + element_codec, None, None, None, 0) return codec @staticmethod @@ -391,7 +415,7 @@ cdef class Codec: cdef Codec codec codec = Codec(oid) codec.init(name, schema, 'multirange', CODEC_MULTIRANGE, - element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, + element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, element_codec, None, None, None, 0) return codec @@ -407,7 +431,7 @@ cdef class Codec: codec = Codec(oid) codec.init(name, schema, 'composite', CODEC_COMPOSITE, format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, - element_type_oids, element_names, element_codecs, 0) + None, element_type_oids, element_names, element_codecs, 0) return codec @staticmethod @@ -419,12 +443,13 @@ cdef class Codec: object decoder, encode_func c_encoder, decode_func c_decoder, + Codec base_codec, ServerDataFormat format, ClientExchangeFormat xformat): cdef Codec codec codec = Codec(oid) codec.init(name, schema, kind, CODEC_PY, format, xformat, - c_encoder, c_decoder, encoder, decoder, + c_encoder, c_decoder, base_codec, encoder, decoder, None, None, None, None, 0) return codec @@ -596,17 +621,21 @@ cdef class DataCodecConfig: self.declare_fallback_codec(oid, name, schema) def add_python_codec(self, typeoid, typename, typeschema, typekind, - encoder, decoder, format, xformat): + typeinfos, encoder, decoder, format, xformat): cdef: - Codec core_codec + Codec core_codec = None encode_func c_encoder = NULL decode_func c_decoder = NULL + Codec base_codec = None uint32_t oid = pylong_as_oid(typeoid) bint codec_set = False # Clear all previous overrides (this also clears type cache). self.remove_python_codec(typeoid, typename, typeschema) + if typeinfos: + self.add_types(typeinfos) + if format == PG_FORMAT_ANY: formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY) else: @@ -614,16 +643,21 @@ cdef class DataCodecConfig: for fmt in formats: if xformat == PG_XFORMAT_TUPLE: - core_codec = get_core_codec(oid, fmt, xformat) - if core_codec is None: - continue - c_encoder = core_codec.c_encoder - c_decoder = core_codec.c_decoder + if typekind == "scalar": + core_codec = get_core_codec(oid, fmt, xformat) + if core_codec is None: + continue + c_encoder = core_codec.c_encoder + c_decoder = core_codec.c_decoder + elif typekind == "composite": + base_codec = self.get_codec(oid, fmt) + if base_codec is None: + continue self._custom_type_codecs[typeoid, fmt] = \ Codec.new_python_codec(oid, typename, typeschema, typekind, encoder, decoder, c_encoder, c_decoder, - fmt, xformat) + base_codec, fmt, xformat) codec_set = True if not codec_set: @@ -829,7 +863,7 @@ cdef register_core_codec(uint32_t oid, codec = Codec(oid) codec.init(name, 'pg_catalog', kind, CODEC_C, format, xformat, - encode, decode, None, None, None, None, None, None, 0) + encode, decode, None, None, None, None, None, None, None, 0) cpython.Py_INCREF(codec) # immortalize if format == PG_FORMAT_BINARY: @@ -853,7 +887,7 @@ cdef register_extra_codec(str name, codec = Codec(INVALIDOID) codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT, - encode, decode, None, None, None, None, None, None, 0) + encode, decode, None, None, None, None, None, None, None, 0) EXTRA_CODECS[name, format] = codec diff --git a/asyncpg/protocol/settings.pxd b/asyncpg/protocol/settings.pxd index 41131cdc..0a1a5f6f 100644 --- a/asyncpg/protocol/settings.pxd +++ b/asyncpg/protocol/settings.pxd @@ -18,7 +18,7 @@ cdef class ConnectionSettings(pgproto.CodecContext): cpdef get_text_codec(self) cpdef inline register_data_types(self, types) cpdef inline add_python_codec( - self, typeoid, typename, typeschema, typekind, encoder, + self, typeoid, typename, typeschema, typeinfos, typekind, encoder, decoder, format) cpdef inline remove_python_codec( self, typeoid, typename, typeschema) diff --git a/asyncpg/protocol/settings.pyx b/asyncpg/protocol/settings.pyx index b4cfa399..8e6591b9 100644 --- a/asyncpg/protocol/settings.pyx +++ b/asyncpg/protocol/settings.pyx @@ -36,7 +36,8 @@ cdef class ConnectionSettings(pgproto.CodecContext): self._data_codecs.add_types(types) cpdef inline add_python_codec(self, typeoid, typename, typeschema, - typekind, encoder, decoder, format): + typeinfos, typekind, encoder, decoder, + format): cdef: ServerDataFormat _format ClientExchangeFormat xformat @@ -57,7 +58,8 @@ cdef class ConnectionSettings(pgproto.CodecContext): )) self._data_codecs.add_python_codec(typeoid, typename, typeschema, - typekind, encoder, decoder, + typekind, typeinfos, + encoder, decoder, _format, xformat) cpdef inline remove_python_codec(self, typeoid, typename, typeschema): diff --git a/docs/usage.rst b/docs/usage.rst index a6c62b41..82a7a370 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -216,7 +216,46 @@ JSON values using the :mod:`json ` module. finally: await conn.close() - asyncio.get_event_loop().run_until_complete(main()) + asyncio.run(main()) + + +Example: complex types +~~~~~~~~~~~~~~~~~~~~~~ + +The example below shows how to configure asyncpg to encode and decode +Python :class:`complex ` values to a custom composite +type in PostgreSQL. + +.. code-block:: python + + import asyncio + import asyncpg + + + async def main(): + conn = await asyncpg.connect() + + try: + await conn.execute( + ''' + CREATE TYPE mycomplex AS ( + r float, + i float + );''' + ) + await conn.set_type_codec( + 'complex', + encoder=lambda x: (x.real, x.imag), + decoder=lambda t: complex(t[0], t[1]), + format='tuple', + ) + + res = await conn.fetchval('SELECT $1::mycomplex', (1+2j)) + + finally: + await conn.close() + + asyncio.run(main()) Example: automatic conversion of PostGIS types @@ -274,7 +313,7 @@ will work. finally: await conn.close() - asyncio.get_event_loop().run_until_complete(main()) + asyncio.run(main()) Example: decoding numeric columns as floats diff --git a/tests/test_codecs.py b/tests/test_codecs.py index 918e01d5..bffb2f1a 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -1212,28 +1212,11 @@ def hstore_encoder(obj): self.assertEqual(at[0].name, 'result') self.assertEqual(at[0].type, pt[0]) - err = 'cannot use custom codec on non-scalar type public._hstore' + err = 'cannot use custom codec on type public._hstore' with self.assertRaisesRegex(asyncpg.InterfaceError, err): await self.con.set_type_codec('_hstore', encoder=hstore_encoder, decoder=hstore_decoder) - - await self.con.execute(''' - CREATE TYPE mytype AS (a int); - ''') - - try: - err = 'cannot use custom codec on non-scalar type ' + \ - 'public.mytype' - with self.assertRaisesRegex(asyncpg.InterfaceError, err): - await self.con.set_type_codec( - 'mytype', encoder=hstore_encoder, - decoder=hstore_decoder) - finally: - await self.con.execute(''' - DROP TYPE mytype; - ''') - finally: await self.con.execute(''' DROP EXTENSION hstore @@ -1546,6 +1529,53 @@ def _decoder(value): finally: await conn.close() + async def test_custom_codec_composite_tuple(self): + await self.con.execute(''' + CREATE TYPE mycomplex AS (r float, i float); + ''') + + try: + await self.con.set_type_codec( + 'mycomplex', + encoder=lambda x: (x.real, x.imag), + decoder=lambda t: complex(t[0], t[1]), + format='tuple', + ) + + num = complex('1+2j') + + res = await self.con.fetchval( + 'SELECT $1::mycomplex', + num, + ) + + self.assertEqual(num, res) + + finally: + await self.con.execute(''' + DROP TYPE mycomplex; + ''') + + async def test_custom_codec_composite_non_tuple(self): + await self.con.execute(''' + CREATE TYPE mycomplex AS (r float, i float); + ''') + + try: + with self.assertRaisesRegex( + asyncpg.UnsupportedClientFeatureError, + "only tuple-format codecs can be used on composite types", + ): + await self.con.set_type_codec( + 'mycomplex', + encoder=lambda x: (x.real, x.imag), + decoder=lambda t: complex(t[0], t[1]), + ) + finally: + await self.con.execute(''' + DROP TYPE mycomplex; + ''') + async def test_timetz_encoding(self): try: async with self.con.transaction():