diff --git a/docs/source/python/extending_types.rst b/docs/source/python/extending_types.rst index b9e875ceebc74..ee92cebcb549c 100644 --- a/docs/source/python/extending_types.rst +++ b/docs/source/python/extending_types.rst @@ -68,34 +68,43 @@ message). See the :ref:`format_metadata_extension_types` section of the metadata specification for more details. -Pyarrow allows you to define such extension types from Python. - -There are currently two ways: - -* Subclassing :class:`PyExtensionType`: the (de)serialization is based on pickle. - This is a good option for an extension type that is only used from Python. -* Subclassing :class:`ExtensionType`: this allows to give a custom - Python-independent name and serialized metadata, that can potentially be - recognized by other (non-Python) Arrow implementations such as PySpark. +Pyarrow allows you to define such extension types from Python by subclassing +:class:`ExtensionType` and giving the derived class its own extension name +and serialization mechanism. The extension name and serialized metadata +can potentially be recognized by other (non-Python) Arrow implementations +such as PySpark. For example, we could define a custom UUID type for 128-bit numbers which can -be represented as ``FixedSizeBinary`` type with 16 bytes. -Using the first approach, we create a ``UuidType`` subclass, and implement the -``__reduce__`` method to ensure the class can be properly pickled:: +be represented as ``FixedSizeBinary`` type with 16 bytes:: - class UuidType(pa.PyExtensionType): + class UuidType(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.binary(16)) + super().__init__(pa.binary(16), "my_package.uuid") + + def __arrow_ext_serialize__(self): + # Since we don't have a parameterized type, we don't need extra + # metadata to be deserialized + return b'' - def __reduce__(self): - return UuidType, () + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + # Sanity checks, not required but illustrate the method signature. + assert storage_type == pa.binary(16) + assert serialized == b'' + # Return an instance of this subclass given the serialized + # metadata. + return UuidType() + +The special methods ``__arrow_ext_serialize__`` and ``__arrow_ext_deserialize__`` +define the serialization of an extension type instance. For non-parametric +types such as the above, the serialization payload can be left empty. This can now be used to create arrays and tables holding the extension type:: >>> uuid_type = UuidType() >>> uuid_type.extension_name - 'arrow.py_extension_type' + 'my_package.uuid' >>> uuid_type.storage_type FixedSizeBinaryType(fixed_size_binary[16]) @@ -112,8 +121,11 @@ This can now be used to create arrays and tables holding the extension type:: ] This array can be included in RecordBatches, sent over IPC and received in -another Python process. The custom UUID type will be preserved there, as long -as the definition of the class is available (the type can be unpickled). +another Python process. The receiving process must explicitly register the +extension type for deserialization, otherwise it will fall back to the +storage type:: + + >>> pa.register_extension_type(UuidType()) For example, creating a RecordBatch and writing it to a stream using the IPC protocol:: @@ -129,43 +141,12 @@ and then reading it back yields the proper type:: >>> with pa.ipc.open_stream(buf) as reader: ... result = reader.read_all() >>> result.column('ext').type - UuidType(extension) - -We can define the same type using the other option:: - - class UuidType(pa.ExtensionType): - - def __init__(self): - pa.ExtensionType.__init__(self, pa.binary(16), "my_package.uuid") - - def __arrow_ext_serialize__(self): - # since we don't have a parameterized type, we don't need extra - # metadata to be deserialized - return b'' - - @classmethod - def __arrow_ext_deserialize__(self, storage_type, serialized): - # return an instance of this subclass given the serialized - # metadata. - return UuidType() - -This is a slightly longer implementation (you need to implement the special -methods ``__arrow_ext_serialize__`` and ``__arrow_ext_deserialize__``), and the -extension type needs to be registered to be received through IPC (using -:func:`register_extension_type`), but it has -now a unique name:: - - >>> uuid_type = UuidType() - >>> uuid_type.extension_name - 'my_package.uuid' - - >>> pa.register_extension_type(uuid_type) + UuidType(FixedSizeBinaryType(fixed_size_binary[16])) The receiving application doesn't need to be Python but can still recognize -the extension type as a "uuid" type, if it has implemented its own extension -type to receive it. -If the type is not registered in the receiving application, it will fall back -to the storage type. +the extension type as a "my_package.uuid" type, if it has implemented its own +extension type to receive it. If the type is not registered in the receiving +application, it will fall back to the storage type. Parameterized extension type ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -187,7 +168,7 @@ of the given frequency since 1970. # attributes need to be set first before calling # super init (as that calls serialize) self._freq = freq - pa.ExtensionType.__init__(self, pa.int64(), 'my_package.period') + super().__init__(pa.int64(), 'my_package.period') @property def freq(self): @@ -198,7 +179,7 @@ of the given frequency since 1970. @classmethod def __arrow_ext_deserialize__(cls, storage_type, serialized): - # return an instance of this subclass given the serialized + # Return an instance of this subclass given the serialized # metadata. serialized = serialized.decode() assert serialized.startswith("freq=") @@ -209,31 +190,10 @@ Here, we ensure to store all information in the serialized metadata that is needed to reconstruct the instance (in the ``__arrow_ext_deserialize__`` class method), in this case the frequency string. -Note that, once created, the data type instance is considered immutable. If, -in the example above, the ``freq`` parameter would change after instantiation, -the reconstruction of the type instance after IPC will be incorrect. +Note that, once created, the data type instance is considered immutable. In the example above, the ``freq`` parameter is therefore stored in a private attribute with a public read-only property to access it. -Parameterized extension types are also possible using the pickle-based type -subclassing :class:`PyExtensionType`. The equivalent example for the period -data type from above would look like:: - - class PeriodType(pa.PyExtensionType): - - def __init__(self, freq): - self._freq = freq - pa.PyExtensionType.__init__(self, pa.int64()) - - @property - def freq(self): - return self._freq - - def __reduce__(self): - return PeriodType, (self.freq,) - -Also the storage type does not need to be fixed but can be parameterized. - Custom extension array class ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -252,12 +212,16 @@ the data as a 2-D Numpy array ``(N, 3)`` without any copy:: return self.storage.flatten().to_numpy().reshape((-1, 3)) - class Point3DType(pa.PyExtensionType): + class Point3DType(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.list_(pa.float32(), 3)) + super().__init__(pa.list_(pa.float32(), 3), "my_package.Point3DType") - def __reduce__(self): - return Point3DType, () + def __arrow_ext_serialize__(self): + return b'' + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + return Point3DType() def __arrow_ext_class__(self): return Point3DArray @@ -289,11 +253,8 @@ The additional methods in the extension class are then available to the user:: This array can be sent over IPC, received in another Python process, and the custom -extension array class will be preserved (as long as the definitions of the classes above -are available). - -The same ``__arrow_ext_class__`` specialization can be used with custom types defined -by subclassing :class:`ExtensionType`. +extension array class will be preserved (as long as the receiving process registers +the extension type using :func:`register_extension_type` before reading the IPC data). Custom scalar conversion ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -304,18 +265,24 @@ If you want scalars of your custom extension type to convert to a custom type wh For example, if we wanted the above example 3D point type to return a custom 3D point class instead of a list, we would implement:: + from collections import namedtuple + Point3D = namedtuple("Point3D", ["x", "y", "z"]) class Point3DScalar(pa.ExtensionScalar): def as_py(self) -> Point3D: return Point3D(*self.value.as_py()) - class Point3DType(pa.PyExtensionType): + class Point3DType(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.list_(pa.float32(), 3)) + super().__init__(pa.list_(pa.float32(), 3), "my_package.Point3DType") - def __reduce__(self): - return Point3DType, () + def __arrow_ext_serialize__(self): + return b'' + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + return Point3DType() def __arrow_ext_scalar_class__(self): return Point3DScalar diff --git a/python/pyarrow/tests/test_cffi.py b/python/pyarrow/tests/test_cffi.py index 55bab4359bffc..a9c17cc100cb4 100644 --- a/python/pyarrow/tests/test_cffi.py +++ b/python/pyarrow/tests/test_cffi.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. +import contextlib import ctypes import gc @@ -51,18 +52,33 @@ def PyCapsule_IsValid(capsule, name): return ctypes.pythonapi.PyCapsule_IsValid(ctypes.py_object(capsule), name) == 1 -class ParamExtType(pa.PyExtensionType): +@contextlib.contextmanager +def registered_extension_type(ext_type): + pa.register_extension_type(ext_type) + try: + yield + finally: + pa.unregister_extension_type(ext_type.extension_name) + + +class ParamExtType(pa.ExtensionType): def __init__(self, width): self._width = width - pa.PyExtensionType.__init__(self, pa.binary(width)) + super().__init__(pa.binary(width), + "pyarrow.tests.test_cffi.ParamExtType") @property def width(self): return self._width - def __reduce__(self): - return ParamExtType, (self.width,) + def __arrow_ext_serialize__(self): + return str(self.width).encode() + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + width = int(serialized.decode()) + return cls(width) def make_schema(): @@ -75,6 +91,12 @@ def make_extension_schema(): metadata={b'key1': b'value1'}) +def make_extension_storage_schema(): + # Should be kept in sync with make_extension_schema + return pa.schema([('ext', ParamExtType(3).storage_type)], + metadata={b'key1': b'value1'}) + + def make_batch(): return pa.record_batch([[[1], [2, 42]]], make_schema()) @@ -204,7 +226,10 @@ def test_export_import_array(): pa.Array._import_from_c(ptr_array, ptr_schema) -def check_export_import_schema(schema_factory): +def check_export_import_schema(schema_factory, expected_schema_factory=None): + if expected_schema_factory is None: + expected_schema_factory = schema_factory + c_schema = ffi.new("struct ArrowSchema*") ptr_schema = int(ffi.cast("uintptr_t", c_schema)) @@ -215,7 +240,7 @@ def check_export_import_schema(schema_factory): assert pa.total_allocated_bytes() > old_allocated # Delete and recreate C++ object from exported pointer schema_new = pa.Schema._import_from_c(ptr_schema) - assert schema_new == schema_factory() + assert schema_new == expected_schema_factory() assert pa.total_allocated_bytes() == old_allocated del schema_new assert pa.total_allocated_bytes() == old_allocated @@ -240,7 +265,13 @@ def test_export_import_schema(): @needs_cffi def test_export_import_schema_with_extension(): - check_export_import_schema(make_extension_schema) + # Extension type is unregistered => the storage type is imported + check_export_import_schema(make_extension_schema, + make_extension_storage_schema) + + # Extension type is registered => the extension type is imported + with registered_extension_type(ParamExtType(1)): + check_export_import_schema(make_extension_schema) @needs_cffi @@ -319,7 +350,8 @@ def test_export_import_batch(): @needs_cffi def test_export_import_batch_with_extension(): - check_export_import_batch(make_extension_batch) + with registered_extension_type(ParamExtType(1)): + check_export_import_batch(make_extension_batch) def _export_import_batch_reader(ptr_stream, reader_factory): diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index ce575d984e41c..a88e20eefe098 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import contextlib import os import shutil import subprocess @@ -29,31 +30,69 @@ import pytest -class TinyIntType(pa.PyExtensionType): +@contextlib.contextmanager +def registered_extension_type(ext_type): + pa.register_extension_type(ext_type) + try: + yield + finally: + pa.unregister_extension_type(ext_type.extension_name) + + +@contextlib.contextmanager +def enabled_auto_load(): + pa.PyExtensionType.set_auto_load(True) + try: + yield + finally: + pa.PyExtensionType.set_auto_load(False) + + +class TinyIntType(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.int8()) + super().__init__(pa.int8(), 'pyarrow.tests.TinyIntType') - def __reduce__(self): - return TinyIntType, () + def __arrow_ext_serialize__(self): + return b'' + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + assert serialized == b'' + assert storage_type == pa.int8() + return cls() -class IntegerType(pa.PyExtensionType): +class IntegerType(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.int64()) + super().__init__(pa.int64(), 'pyarrow.tests.IntegerType') - def __reduce__(self): - return IntegerType, () + def __arrow_ext_serialize__(self): + return b'' + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + assert serialized == b'' + assert storage_type == pa.int64() + return cls() -class IntegerEmbeddedType(pa.PyExtensionType): +class IntegerEmbeddedType(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, IntegerType()) + super().__init__(IntegerType(), 'pyarrow.tests.IntegerType') - def __reduce__(self): - return IntegerEmbeddedType, () + def __arrow_ext_serialize__(self): + # XXX pa.BaseExtensionType should expose C++ serialization method + return self.storage_type.__arrow_ext_serialize__() + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + deserialized_storage_type = storage_type.__arrow_ext_deserialize__( + serialized) + assert deserialized_storage_type == storage_type + return cls() class UuidScalarType(pa.ExtensionScalar): @@ -61,81 +100,125 @@ def as_py(self): return None if self.value is None else UUID(bytes=self.value.as_py()) -class UuidType(pa.PyExtensionType): +class UuidType(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.binary(16)) - - def __reduce__(self): - return UuidType, () + super().__init__(pa.binary(16), 'pyarrow.tests.UuidType') def __arrow_ext_scalar_class__(self): return UuidScalarType + def __arrow_ext_serialize__(self): + return b'' -class UuidType2(pa.PyExtensionType): + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + return cls() + + +class UuidType2(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.binary(16)) + super().__init__(pa.binary(16), 'pyarrow.tests.UuidType2') - def __reduce__(self): - return UuidType2, () + def __arrow_ext_serialize__(self): + return b'' + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + return cls() -class LabelType(pa.PyExtensionType): +class LabelType(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.string()) + super().__init__(pa.string(), 'pyarrow.tests.LabelType') - def __reduce__(self): - return LabelType, () + def __arrow_ext_serialize__(self): + return b'' + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + return cls() -class ParamExtType(pa.PyExtensionType): +class ParamExtType(pa.ExtensionType): def __init__(self, width): self._width = width - pa.PyExtensionType.__init__(self, pa.binary(width)) + super().__init__(pa.binary(width), 'pyarrow.tests.ParamExtType') @property def width(self): return self._width - def __reduce__(self): - return ParamExtType, (self.width,) + def __arrow_ext_serialize__(self): + return str(self._width).encode() + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + width = int(serialized.decode()) + assert storage_type == pa.binary(width) + return cls(width) -class MyStructType(pa.PyExtensionType): +class MyStructType(pa.ExtensionType): storage_type = pa.struct([('left', pa.int64()), ('right', pa.int64())]) def __init__(self): - pa.PyExtensionType.__init__(self, self.storage_type) + super().__init__(self.storage_type, 'pyarrow.tests.MyStructType') - def __reduce__(self): - return MyStructType, () + def __arrow_ext_serialize__(self): + return b'' + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + assert serialized == b'' + assert storage_type == cls.storage_type + return cls() -class MyListType(pa.PyExtensionType): + +class MyListType(pa.ExtensionType): def __init__(self, storage_type): - pa.PyExtensionType.__init__(self, storage_type) + assert isinstance(storage_type, pa.ListType) + super().__init__(storage_type, 'pyarrow.tests.MyListType') - def __reduce__(self): - return MyListType, (self.storage_type,) + def __arrow_ext_serialize__(self): + return b'' + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + assert serialized == b'' + return cls(storage_type) -class AnnotatedType(pa.PyExtensionType): +class AnnotatedType(pa.ExtensionType): """ Generic extension type that can store any storage type. """ def __init__(self, storage_type, annotation): self.annotation = annotation - super().__init__(storage_type) + super().__init__(storage_type, 'pyarrow.tests.AnnotatedType') + + def __arrow_ext_serialize__(self): + return b'' + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + assert serialized == b'' + return cls(storage_type) + + +class LegacyIntType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, pa.int8()) def __reduce__(self): - return AnnotatedType, (self.storage_type, self.annotation) + return LegacyIntType, () def ipc_write_batch(batch): @@ -153,12 +236,12 @@ def ipc_read_batch(buf): def test_ext_type_basics(): ty = UuidType() - assert ty.extension_name == "arrow.py_extension_type" + assert ty.extension_name == "pyarrow.tests.UuidType" def test_ext_type_str(): ty = IntegerType() - expected = "extension>" + expected = "extension>" assert str(ty) == expected assert pa.DataType.__str__(ty) == expected @@ -223,7 +306,7 @@ def test_uuid_type_pickle(pickle_module): del ty ty = pickle_module.loads(ser) wr = weakref.ref(ty) - assert ty.extension_name == "arrow.py_extension_type" + assert ty.extension_name == "pyarrow.tests.UuidType" del ty assert wr() is None @@ -571,9 +654,9 @@ def test_cast_between_extension_types(): assert tiny_int_arr.type == TinyIntType() # Casting between extension types w/ different storage types not okay. - msg = ("Casting from 'extension>' " + msg = ("Casting from 'extension<.*?>' " "to different extension type " - "'extension>' not permitted. " + "'extension<.*?>' not permitted. " "One can first cast to the storage type, " "then to the extension type." ) @@ -660,53 +743,38 @@ def example_batch(): return pa.RecordBatch.from_arrays([arr], ["exts"]) -def check_example_batch(batch): +def check_example_batch(batch, *, expect_extension): arr = batch.column(0) - assert isinstance(arr, pa.ExtensionArray) - assert arr.type.storage_type == pa.binary(3) - assert arr.storage.to_pylist() == [b"foo", b"bar"] + if expect_extension: + assert isinstance(arr, pa.ExtensionArray) + assert arr.type.storage_type == pa.binary(3) + assert arr.storage.to_pylist() == [b"foo", b"bar"] + else: + assert arr.type == pa.binary(3) + assert arr.to_pylist() == [b"foo", b"bar"] return arr -def test_ipc(): +def test_ipc_unregistered(): batch = example_batch() buf = ipc_write_batch(batch) del batch batch = ipc_read_batch(buf) - arr = check_example_batch(batch) - assert arr.type == ParamExtType(3) + batch.validate(full=True) + check_example_batch(batch, expect_extension=False) -def test_ipc_unknown_type(): - batch = example_batch() - buf = ipc_write_batch(batch) - del batch - - orig_type = ParamExtType - try: - # Simulate the original Python type being unavailable. - # Deserialization should not fail but return a placeholder type. - del globals()['ParamExtType'] +def test_ipc_registered(): + with registered_extension_type(ParamExtType(1)): + batch = example_batch() + buf = ipc_write_batch(batch) + del batch batch = ipc_read_batch(buf) - arr = check_example_batch(batch) - assert isinstance(arr.type, pa.UnknownExtensionType) - - # Can be serialized again - buf2 = ipc_write_batch(batch) - del batch, arr - - batch = ipc_read_batch(buf2) - arr = check_example_batch(batch) - assert isinstance(arr.type, pa.UnknownExtensionType) - finally: - globals()['ParamExtType'] = orig_type - - # Deserialize again with the type restored - batch = ipc_read_batch(buf2) - arr = check_example_batch(batch) - assert arr.type == ParamExtType(3) + batch.validate(full=True) + arr = check_example_batch(batch, expect_extension=True) + assert arr.type == ParamExtType(3) class PeriodArray(pa.ExtensionArray): @@ -930,6 +998,7 @@ def test_parquet_period(tmpdir, registered_period_type): # When reading in, properly create extension type if it is registered result = pq.read_table(filename) + result.validate(full=True) assert result.schema.field("ext").type == period_type assert result.schema.field("ext").metadata == {} # Get the exact array class defined by the registered type. @@ -939,6 +1008,7 @@ def test_parquet_period(tmpdir, registered_period_type): # When the type is not registered, read in as storage type pa.unregister_extension_type(period_type.extension_name) result = pq.read_table(filename) + result.validate(full=True) assert result.schema.field("ext").type == pa.int64() # The extension metadata is present for roundtripping. assert result.schema.field("ext").metadata == { @@ -967,13 +1037,28 @@ def test_parquet_extension_with_nested_storage(tmpdir): filename = tmpdir / 'nested_extension_storage.parquet' pq.write_table(orig_table, filename) + # Unregistered table = pq.read_table(filename) - assert table.column('structs').type == mystruct_array.type - assert table.column('lists').type == mylist_array.type - assert table == orig_table - - with pytest.raises(pa.ArrowInvalid, match='without all of its fields'): - pq.ParquetFile(filename).read(columns=['structs.left']) + table.validate(full=True) + assert table.column('structs').type == struct_array.type + assert table.column('structs').combine_chunks() == struct_array + assert table.column('lists').type == list_array.type + assert table.column('lists').combine_chunks() == list_array + + # Registered + with registered_extension_type(mystruct_array.type): + with registered_extension_type(mylist_array.type): + table = pq.read_table(filename) + table.validate(full=True) + assert table.column('structs').type == mystruct_array.type + assert table.column('lists').type == mylist_array.type + assert table == orig_table + + # Cannot select a subfield of an extension type with + # a struct storage type. + with pytest.raises(pa.ArrowInvalid, + match='without all of its fields'): + pq.ParquetFile(filename).read(columns=['structs.left']) @pytest.mark.parquet @@ -995,8 +1080,14 @@ def test_parquet_nested_extension(tmpdir): pq.write_table(orig_table, filename) table = pq.read_table(filename) - assert table.column(0).type == struct_array.type - assert table == orig_table + table.validate(full=True) + assert table.column(0).type == pa.struct({'ints': pa.int64(), + 'exts': pa.int64()}) + with registered_extension_type(ext_type): + table = pq.read_table(filename) + table.validate(full=True) + assert table.column(0).type == struct_array.type + assert table == orig_table # List of extensions list_array = pa.ListArray.from_arrays([0, 1, None, 3], ext_array) @@ -1006,8 +1097,13 @@ def test_parquet_nested_extension(tmpdir): pq.write_table(orig_table, filename) table = pq.read_table(filename) - assert table.column(0).type == list_array.type - assert table == orig_table + table.validate(full=True) + assert table.column(0).type == pa.list_(pa.int64()) + with registered_extension_type(ext_type): + table = pq.read_table(filename) + table.validate(full=True) + assert table.column(0).type == list_array.type + assert table == orig_table # Large list of extensions list_array = pa.LargeListArray.from_arrays([0, 1, None, 3], ext_array) @@ -1017,8 +1113,13 @@ def test_parquet_nested_extension(tmpdir): pq.write_table(orig_table, filename) table = pq.read_table(filename) - assert table.column(0).type == list_array.type - assert table == orig_table + table.validate(full=True) + assert table.column(0).type == pa.large_list(pa.int64()) + with registered_extension_type(ext_type): + table = pq.read_table(filename) + table.validate(full=True) + assert table.column(0).type == list_array.type + assert table == orig_table @pytest.mark.parquet @@ -1040,8 +1141,12 @@ def test_parquet_extension_nested_in_extension(tmpdir): pq.write_table(orig_table, filename) table = pq.read_table(filename) - assert table.column(0).type == mylist_array.type - assert table == orig_table + assert table.column(0).type == pa.list_(pa.int64()) + with registered_extension_type(mylist_array.type): + with registered_extension_type(inner_ext_array.type): + table = pq.read_table(filename) + assert table.column(0).type == mylist_array.type + assert table == orig_table def test_to_numpy(): @@ -1370,3 +1475,25 @@ def test_tensor_type_is_picklable(pickle_module): def test_tensor_type_str(tensor_type, text): tensor_type_str = tensor_type.__str__() assert text in tensor_type_str + + +def test_legacy_int_type(): + with pytest.warns(FutureWarning, match="PyExtensionType is deprecated"): + ext_ty = LegacyIntType() + arr = pa.array([1, 2, 3], type=ext_ty.storage_type) + ext_arr = pa.ExtensionArray.from_storage(ext_ty, arr) + batch = pa.RecordBatch.from_arrays([ext_arr], names=['ext']) + buf = ipc_write_batch(batch) + + with pytest.warns( + RuntimeWarning, + match="pickle-based deserialization of pyarrow.PyExtensionType " + "subclasses is disabled by default"): + batch = ipc_read_batch(buf) + assert isinstance(batch.column(0).type, pa.UnknownExtensionType) + + with enabled_auto_load(): + with pytest.warns(FutureWarning, match="PyExtensionType is deprecated"): + batch = ipc_read_batch(buf) + assert isinstance(batch.column(0).type, LegacyIntType) + assert batch.column(0) == ext_arr diff --git a/python/pyarrow/tests/test_pandas.py b/python/pyarrow/tests/test_pandas.py index 62a9443953a3d..10eb931592093 100644 --- a/python/pyarrow/tests/test_pandas.py +++ b/python/pyarrow/tests/test_pandas.py @@ -4096,13 +4096,20 @@ def test_array_protocol(): assert result.equals(expected2) -class DummyExtensionType(pa.PyExtensionType): +class DummyExtensionType(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.int64()) + super().__init__(pa.int64(), + 'pyarrow.tests.test_pandas.DummyExtensionType') - def __reduce__(self): - return DummyExtensionType, () + def __arrow_ext_serialize__(self): + return b'' + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + assert serialized == b'' + assert storage_type == pa.int64() + return cls() def PandasArray__arrow_array__(self, type=None): @@ -4198,13 +4205,14 @@ def test_convert_to_extension_array(monkeypatch): assert not isinstance(_get_mgr(result).blocks[0], _int.ExtensionBlock) -class MyCustomIntegerType(pa.PyExtensionType): +class MyCustomIntegerType(pa.ExtensionType): def __init__(self): - pa.PyExtensionType.__init__(self, pa.int64()) + super().__init__(pa.int64(), + 'pyarrow.tests.test_pandas.MyCustomIntegerType') - def __reduce__(self): - return MyCustomIntegerType, () + def __arrow_ext_serialize__(self): + return b'' def to_pandas_dtype(self): return pd.Int64Dtype() diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index d394b803e7fc2..a0ddf09d69423 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1437,7 +1437,10 @@ cdef class ExtensionType(BaseExtensionType): Parameters ---------- storage_type : DataType + The underlying storage type for the extension type. extension_name : str + A unique name distinguishing this extension type. The name will be + used when deserializing IPC data. Examples -------- @@ -1671,60 +1674,22 @@ cdef class FixedShapeTensorType(BaseExtensionType): self.dim_names, self.permutation) +_py_extension_type_auto_load = False + + cdef class PyExtensionType(ExtensionType): """ Concrete base class for Python-defined extension types based on pickle for (de)serialization. + .. warning:: + This class is deprecated and its deserialization is disabled by default. + :class:`ExtensionType` is recommended instead. + Parameters ---------- storage_type : DataType The storage type for which the extension is built. - - Examples - -------- - Define a UuidType extension type subclassing PyExtensionType: - - >>> import pyarrow as pa - >>> class UuidType(pa.PyExtensionType): - ... def __init__(self): - ... pa.PyExtensionType.__init__(self, pa.binary(16)) - ... def __reduce__(self): - ... return UuidType, () - ... - - Create an instance of UuidType extension type: - - >>> uuid_type = UuidType() # doctest: +SKIP - >>> uuid_type # doctest: +SKIP - UuidType(FixedSizeBinaryType(fixed_size_binary[16])) - - Inspect the extension type: - - >>> uuid_type.extension_name # doctest: +SKIP - 'arrow.py_extension_type' - >>> uuid_type.storage_type # doctest: +SKIP - FixedSizeBinaryType(fixed_size_binary[16]) - - Wrap an array as an extension array: - - >>> import uuid - >>> storage_array = pa.array([uuid.uuid4().bytes for _ in range(4)], - ... pa.binary(16)) # doctest: +SKIP - >>> uuid_type.wrap_array(storage_array) # doctest: +SKIP - - [ - ... - ] - - Or do the same with creating an ExtensionArray: - - >>> pa.ExtensionArray.from_storage(uuid_type, - ... storage_array) # doctest: +SKIP - - [ - ... - ] """ def __cinit__(self): @@ -1733,6 +1698,12 @@ cdef class PyExtensionType(ExtensionType): "PyExtensionType") def __init__(self, DataType storage_type): + warnings.warn( + "pyarrow.PyExtensionType is deprecated " + "and will refuse deserialization by default. " + "Instead, please derive from pyarrow.ExtensionType and implement " + "your own serialization mechanism.", + FutureWarning) ExtensionType.__init__(self, storage_type, "arrow.py_extension_type") def __reduce__(self): @@ -1744,6 +1715,17 @@ cdef class PyExtensionType(ExtensionType): @classmethod def __arrow_ext_deserialize__(cls, storage_type, serialized): + if not _py_extension_type_auto_load: + warnings.warn( + "pickle-based deserialization of pyarrow.PyExtensionType subclasses " + "is disabled by default; if you only ingest " + "trusted data files, you may re-enable this using " + "`pyarrow.PyExtensionType.set_auto_load(True)`.\n" + "In the future, Python-defined extension subclasses should " + "derive from pyarrow.ExtensionType (not pyarrow.PyExtensionType) " + "and implement their own serialization mechanism.\n", + RuntimeWarning) + return UnknownExtensionType(storage_type, serialized) try: ty = pickle.loads(serialized) except Exception: @@ -1759,6 +1741,22 @@ cdef class PyExtensionType(ExtensionType): .format(ty.storage_type, storage_type)) return ty + # XXX Cython marks extension types as immutable, so cannot expose this + # as a writable class attribute. + @classmethod + def set_auto_load(cls, value): + """ + Enable or disable auto-loading of serialized PyExtensionType instances. + + Parameters + ---------- + value : bool + Whether to enable auto-loading. + """ + global _py_extension_type_auto_load + assert isinstance(value, bool) + _py_extension_type_auto_load = value + cdef class UnknownExtensionType(PyExtensionType): """