Skip to content

Commit

Permalink
GH-34884: [Python]: Support pickling pyarrow.dataset PartitioningFact…
Browse files Browse the repository at this point in the history
…ory objects (#36550)

### Rationale for this change

#36462 already added support for pickling Partitioning objects, but not yet the PartitioningFactory objects.

The problem for PartitioningFactory is that we currently don't really expose the full class hierarchy in python, just the base class PartitioningFactory. We also don't expose creating those factory objects, except through the `discover` methods of the Partitioning classes. 
I think it would be nice to keep this minimal binding. But that means that if we want to make these objects serializable with pickle — without adding custom serialization code on the C++ side — we need another way to do it.

In this PR, I went for the route of essentially storing the constructor (the discover static method) and the arguments that were passed to the constructor, on the factory object, so we can use this info for pickling. Not the nicest code, but the simplest solution I could think of.

### Are these changes tested?

Yes
* Closes: #34884

Authored-by: Joris Van den Bossche <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
  • Loading branch information
jorisvandenbossche authored Jul 10, 2023
1 parent 3d00668 commit a63ead7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 25 deletions.
5 changes: 4 additions & 1 deletion python/pyarrow/_dataset.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,14 @@ cdef class PartitioningFactory(_Weakrefable):
cdef:
shared_ptr[CPartitioningFactory] wrapped
CPartitioningFactory* factory
object constructor
object options

cdef init(self, const shared_ptr[CPartitioningFactory]& sp)

@staticmethod
cdef wrap(const shared_ptr[CPartitioningFactory]& sp)
cdef wrap(const shared_ptr[CPartitioningFactory]& sp,
object constructor, object options)

cdef inline shared_ptr[CPartitioningFactory] unwrap(self)

Expand Down
39 changes: 34 additions & 5 deletions python/pyarrow/_dataset.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2374,16 +2374,22 @@ cdef class PartitioningFactory(_Weakrefable):
self.factory = sp.get()

@staticmethod
cdef wrap(const shared_ptr[CPartitioningFactory]& sp):
cdef wrap(const shared_ptr[CPartitioningFactory]& sp,
object constructor, object options):
cdef PartitioningFactory self = PartitioningFactory.__new__(
PartitioningFactory
)
self.init(sp)
self.constructor = constructor
self.options = options
return self

cdef inline shared_ptr[CPartitioningFactory] unwrap(self):
return self.wrapped

def __reduce__(self):
    # Pickle support. ``self.constructor`` and ``self.options`` are stored
    # by ``PartitioningFactory.wrap`` (see the staticmethod above): the
    # constructor is a module-level helper that re-invokes the original
    # ``discover`` call, and ``options`` is the tuple of arguments to pass it.
    # NOTE(review): factories wrapped with ``wrap(sp, None, None)`` (the
    # ``partitioning_factory`` property getters) will return ``(None, None)``
    # here, which makes pickle raise — presumably acceptable for
    # discovered-from-C++ factories; confirm.
    return self.constructor, self.options

@property
def type_name(self):
return frombytes(self.factory.type_name())
Expand Down Expand Up @@ -2454,6 +2460,10 @@ cdef class KeyValuePartitioning(Partitioning):
return res


def _constructor_directory_partitioning_factory(*args):
    """Module-level reconstructor used when pickling a PartitioningFactory.

    Simply forwards the stored arguments to
    ``DirectoryPartitioning.discover``; it exists because the classmethod
    itself cannot serve as a pickle constructor.
    """
    discover = DirectoryPartitioning.discover
    return discover(*args)


cdef class DirectoryPartitioning(KeyValuePartitioning):
"""
A Partitioning based on a specified Schema.
Expand Down Expand Up @@ -2571,7 +2581,15 @@ cdef class DirectoryPartitioning(KeyValuePartitioning):
c_options.segment_encoding = _get_segment_encoding(segment_encoding)

return PartitioningFactory.wrap(
CDirectoryPartitioning.MakeFactory(c_field_names, c_options))
CDirectoryPartitioning.MakeFactory(c_field_names, c_options),
_constructor_directory_partitioning_factory,
(field_names, infer_dictionary, max_partition_dictionary_size,
schema, segment_encoding)
)


def _constructor_hive_partitioning_factory(*args):
    """Module-level reconstructor used when pickling a PartitioningFactory.

    Simply forwards the stored arguments to ``HivePartitioning.discover``;
    it exists because the classmethod itself cannot serve as a pickle
    constructor.
    """
    discover = HivePartitioning.discover
    return discover(*args)


cdef class HivePartitioning(KeyValuePartitioning):
Expand Down Expand Up @@ -2714,7 +2732,15 @@ cdef class HivePartitioning(KeyValuePartitioning):
c_options.segment_encoding = _get_segment_encoding(segment_encoding)

return PartitioningFactory.wrap(
CHivePartitioning.MakeFactory(c_options))
CHivePartitioning.MakeFactory(c_options),
_constructor_hive_partitioning_factory,
(infer_dictionary, max_partition_dictionary_size, null_fallback,
schema, segment_encoding),
)


def _constructor_filename_partitioning_factory(*args):
    """Module-level reconstructor used when pickling a PartitioningFactory.

    Simply forwards the stored arguments to
    ``FilenamePartitioning.discover``; it exists because the classmethod
    itself cannot serve as a pickle constructor.
    """
    discover = FilenamePartitioning.discover
    return discover(*args)


cdef class FilenamePartitioning(KeyValuePartitioning):
Expand Down Expand Up @@ -2823,7 +2849,10 @@ cdef class FilenamePartitioning(KeyValuePartitioning):
c_options.segment_encoding = _get_segment_encoding(segment_encoding)

return PartitioningFactory.wrap(
CFilenamePartitioning.MakeFactory(c_field_names, c_options))
CFilenamePartitioning.MakeFactory(c_field_names, c_options),
_constructor_filename_partitioning_factory,
(field_names, infer_dictionary, schema, segment_encoding)
)


cdef class DatasetFactory(_Weakrefable):
Expand Down Expand Up @@ -2988,7 +3017,7 @@ cdef class FileSystemFactoryOptions(_Weakrefable):
c_factory = self.options.partitioning.factory()
if c_factory.get() == nullptr:
return None
return PartitioningFactory.wrap(c_factory)
return PartitioningFactory.wrap(c_factory, None, None)

@partitioning_factory.setter
def partitioning_factory(self, PartitioningFactory value):
Expand Down
2 changes: 1 addition & 1 deletion python/pyarrow/_dataset_parquet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ cdef class ParquetFactoryOptions(_Weakrefable):
c_factory = self.options.partitioning.factory()
if c_factory.get() == nullptr:
return None
return PartitioningFactory.wrap(c_factory)
return PartitioningFactory.wrap(c_factory, None, None)

@partitioning_factory.setter
def partitioning_factory(self, PartitioningFactory value):
Expand Down
59 changes: 41 additions & 18 deletions python/pyarrow/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,12 +1642,15 @@ def test_fragments_repr(tempdir, dataset):


@pytest.mark.parquet
def test_partitioning_factory(mockfs):
@pytest.mark.parametrize(
"pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
def test_partitioning_factory(mockfs, pickled):
paths_or_selector = fs.FileSelector('subdir', recursive=True)
format = ds.ParquetFileFormat()

options = ds.FileSystemFactoryOptions('subdir')
partitioning_factory = ds.DirectoryPartitioning.discover(['group', 'key'])
partitioning_factory = pickled(partitioning_factory)
assert isinstance(partitioning_factory, ds.PartitioningFactory)
options.partitioning_factory = partitioning_factory

Expand All @@ -1673,13 +1676,16 @@ def test_partitioning_factory(mockfs):

@pytest.mark.parquet
@pytest.mark.parametrize('infer_dictionary', [False, True])
def test_partitioning_factory_dictionary(mockfs, infer_dictionary):
@pytest.mark.parametrize(
"pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
def test_partitioning_factory_dictionary(mockfs, infer_dictionary, pickled):
paths_or_selector = fs.FileSelector('subdir', recursive=True)
format = ds.ParquetFileFormat()
options = ds.FileSystemFactoryOptions('subdir')

options.partitioning_factory = ds.DirectoryPartitioning.discover(
partitioning_factory = ds.DirectoryPartitioning.discover(
['group', 'key'], infer_dictionary=infer_dictionary)
options.partitioning_factory = pickled(partitioning_factory)

factory = ds.FileSystemDatasetFactory(
mockfs, paths_or_selector, format, options)
Expand All @@ -1703,7 +1709,9 @@ def test_partitioning_factory_dictionary(mockfs, infer_dictionary):
assert inferred_schema.field('key').type == pa.string()


def test_partitioning_factory_segment_encoding():
@pytest.mark.parametrize(
"pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
def test_partitioning_factory_segment_encoding(pickled):
mockfs = fs._MockFileSystem()
format = ds.IpcFileFormat()
schema = pa.schema([("i64", pa.int64())])
Expand All @@ -1726,8 +1734,9 @@ def test_partitioning_factory_segment_encoding():
# Directory
selector = fs.FileSelector("directory", recursive=True)
options = ds.FileSystemFactoryOptions("directory")
options.partitioning_factory = ds.DirectoryPartitioning.discover(
partitioning_factory = ds.DirectoryPartitioning.discover(
schema=partition_schema)
options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
inferred_schema = factory.inspect()
assert inferred_schema == full_schema
Expand All @@ -1736,24 +1745,27 @@ def test_partitioning_factory_segment_encoding():
})
assert actual[0][0].as_py() == 1620086400

options.partitioning_factory = ds.DirectoryPartitioning.discover(
partitioning_factory = ds.DirectoryPartitioning.discover(
["date", "string"], segment_encoding="none")
options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("date") == "2021-05-04 00%3A00%3A00") &
(ds.field("string") == "%24"))

options.partitioning = ds.DirectoryPartitioning(
partitioning = ds.DirectoryPartitioning(
string_partition_schema, segment_encoding="none")
options.partitioning = pickled(partitioning)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("date") == "2021-05-04 00%3A00%3A00") &
(ds.field("string") == "%24"))

options.partitioning_factory = ds.DirectoryPartitioning.discover(
partitioning_factory = ds.DirectoryPartitioning.discover(
schema=partition_schema, segment_encoding="none")
options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
with pytest.raises(pa.ArrowInvalid,
match="Could not cast segments for partition field"):
Expand All @@ -1762,8 +1774,9 @@ def test_partitioning_factory_segment_encoding():
# Hive
selector = fs.FileSelector("hive", recursive=True)
options = ds.FileSystemFactoryOptions("hive")
options.partitioning_factory = ds.HivePartitioning.discover(
partitioning_factory = ds.HivePartitioning.discover(
schema=partition_schema)
options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
inferred_schema = factory.inspect()
assert inferred_schema == full_schema
Expand All @@ -1772,8 +1785,9 @@ def test_partitioning_factory_segment_encoding():
})
assert actual[0][0].as_py() == 1620086400

options.partitioning_factory = ds.HivePartitioning.discover(
partitioning_factory = ds.HivePartitioning.discover(
segment_encoding="none")
options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
Expand All @@ -1788,15 +1802,18 @@ def test_partitioning_factory_segment_encoding():
(ds.field("date") == "2021-05-04 00%3A00%3A00") &
(ds.field("string") == "%24"))

options.partitioning_factory = ds.HivePartitioning.discover(
partitioning_factory = ds.HivePartitioning.discover(
schema=partition_schema, segment_encoding="none")
options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
with pytest.raises(pa.ArrowInvalid,
match="Could not cast segments for partition field"):
inferred_schema = factory.inspect()


def test_partitioning_factory_hive_segment_encoding_key_encoded():
@pytest.mark.parametrize(
"pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
def test_partitioning_factory_hive_segment_encoding_key_encoded(pickled):
mockfs = fs._MockFileSystem()
format = ds.IpcFileFormat()
schema = pa.schema([("i64", pa.int64())])
Expand Down Expand Up @@ -1825,8 +1842,9 @@ def test_partitioning_factory_hive_segment_encoding_key_encoded():
# Hive
selector = fs.FileSelector("hive", recursive=True)
options = ds.FileSystemFactoryOptions("hive")
options.partitioning_factory = ds.HivePartitioning.discover(
partitioning_factory = ds.HivePartitioning.discover(
schema=partition_schema)
options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
inferred_schema = factory.inspect()
assert inferred_schema == full_schema
Expand All @@ -1835,40 +1853,45 @@ def test_partitioning_factory_hive_segment_encoding_key_encoded():
})
assert actual[0][0].as_py() == 1620086400

options.partitioning_factory = ds.HivePartitioning.discover(
partitioning_factory = ds.HivePartitioning.discover(
segment_encoding="uri")
options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("test'; date") == "2021-05-04 00:00:00") &
(ds.field("test';[ string'") == "$"))

options.partitioning = ds.HivePartitioning(
partitioning = ds.HivePartitioning(
string_partition_schema, segment_encoding="uri")
options.partitioning = pickled(partitioning)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("test'; date") == "2021-05-04 00:00:00") &
(ds.field("test';[ string'") == "$"))

options.partitioning_factory = ds.HivePartitioning.discover(
partitioning_factory = ds.HivePartitioning.discover(
segment_encoding="none")
options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("test%27%3B%20date") == "2021-05-04 00%3A00%3A00") &
(ds.field("test%27%3B%5B%20string%27") == "%24"))

options.partitioning = ds.HivePartitioning(
partitioning = ds.HivePartitioning(
string_partition_schema_en, segment_encoding="none")
options.partitioning = pickled(partitioning)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("test%27%3B%20date") == "2021-05-04 00%3A00%3A00") &
(ds.field("test%27%3B%5B%20string%27") == "%24"))

options.partitioning_factory = ds.HivePartitioning.discover(
partitioning_factory = ds.HivePartitioning.discover(
schema=partition_schema_en, segment_encoding="none")
options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
with pytest.raises(pa.ArrowInvalid,
match="Could not cast segments for partition field"):
Expand Down

0 comments on commit a63ead7

Please sign in to comment.