Skip to content

Commit

Permalink
Fixes to_layout with allow_record=False and allows single-record writing to Arrow and Parquet (#1456)
Browse files Browse the repository at this point in the history

* First, fix #1453.

* Turn Records into length-1 Arrays in v1.

* Implemented record_is_scalar in metadata, but no tests yet.

* Add a test.

* Also handle single Records in to_arrow_table.
  • Loading branch information
jpivarski authored Apr 29, 2022
1 parent 4e6a758 commit de7cae8
Show file tree
Hide file tree
Showing 17 changed files with 151 additions and 26 deletions.
21 changes: 19 additions & 2 deletions src/awkward/_v2/_connect/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,14 @@ def __init__(
mask_parameters,
node_parameters,
record_is_tuple,
record_is_scalar,
):
self._mask_type = mask_type
self._node_type = node_type
self._mask_parameters = mask_parameters
self._node_parameters = node_parameters
self._record_is_tuple = record_is_tuple
self._record_is_scalar = record_is_scalar
super().__init__(storage_type, "awkward")

def __str__(self):
Expand Down Expand Up @@ -127,6 +129,10 @@ def node_parameters(self):
def record_is_tuple(self):
return self._record_is_tuple

@property
def record_is_scalar(self):
return self._record_is_scalar

def __arrow_ext_class__(self):
return AwkwardArrowArray

Expand All @@ -138,6 +144,7 @@ def __arrow_ext_serialize__(self):
"mask_parameters": self._mask_parameters,
"node_parameters": self._node_parameters,
"record_is_tuple": self._record_is_tuple,
"record_is_scalar": self._record_is_scalar,
}
).encode(errors="surrogatescape")

Expand All @@ -151,6 +158,7 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized):
metadata["mask_parameters"],
metadata["node_parameters"],
metadata["record_is_tuple"],
metadata["record_is_scalar"],
)

@property
Expand All @@ -162,7 +170,7 @@ def num_fields(self):
return self.storage_type.num_fields

pyarrow.register_extension_type(
AwkwardArrowType(pyarrow.null(), None, None, None, None, None)
AwkwardArrowType(pyarrow.null(), None, None, None, None, None, None)
)

# order is important; _string_like[:2] vs _string_like[::2]
Expand Down Expand Up @@ -861,7 +869,9 @@ def form_popbuffers(awkwardarrow_type, storage_type):
)


def to_awkwardarrow_type(storage_type, use_extensionarray, mask, node):
def to_awkwardarrow_type(
storage_type, use_extensionarray, record_is_scalar, mask, node
):
if use_extensionarray:
return AwkwardArrowType(
storage_type,
Expand All @@ -870,6 +880,7 @@ def to_awkwardarrow_type(storage_type, use_extensionarray, mask, node):
None if mask is None else mask.parameters,
None if node is None else node.parameters,
node.is_tuple if isinstance(node, ak._v2.contents.RecordArray) else None,
record_is_scalar,
)
else:
return storage_type
Expand Down Expand Up @@ -929,6 +940,7 @@ def handle_arrow(obj, generate_bitmasks=False, pass_empty_field=False):
else:
record_is_optiontype = False
optiontype_fields = []
record_is_scalar = False
optiontype_parameters = None
recordtype_parameters = None
if (
Expand All @@ -940,6 +952,8 @@ def handle_arrow(obj, generate_bitmasks=False, pass_empty_field=False):
(value,) = x.values()
if key == "optiontype_fields":
optiontype_fields = value
elif key == "record_is_scalar":
record_is_scalar = value
elif key in (
"UnmaskedArray",
"BitMaskedArray",
Expand Down Expand Up @@ -975,6 +989,9 @@ def handle_arrow(obj, generate_bitmasks=False, pass_empty_field=False):
parameters=recordtype_parameters,
)

if record_is_scalar:
return out._getitem_at(0)

if record_is_optiontype and record_mask is None and generate_bitmasks:
record_mask = numpy.zeros(len(out), dtype=np.bool_)

Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,7 @@ def to_arrow(
categorical_as_dictionary=False,
extensionarray=True,
count_nulls=True,
record_is_scalar=False,
):
import awkward._v2._connect.pyarrow

Expand All @@ -1288,6 +1289,7 @@ def to_arrow(
"categorical_as_dictionary": categorical_as_dictionary,
"extensionarray": extensionarray,
"count_nulls": count_nulls,
"record_is_scalar": record_is_scalar,
},
)

Expand Down
6 changes: 5 additions & 1 deletion src/awkward/_v2/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,11 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
if options["emptyarray_to"] is None:
return pyarrow.Array.from_buffers(
ak._v2._connect.pyarrow.to_awkwardarrow_type(
pyarrow.null(), options["extensionarray"], mask_node, self
pyarrow.null(),
options["extensionarray"],
options["record_is_scalar"],
mask_node,
self,
),
length,
[
Expand Down
6 changes: 5 additions & 1 deletion src/awkward/_v2/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,11 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
if options["extensionarray"]:
return ak._v2._connect.pyarrow.AwkwardArrowArray.from_storage(
ak._v2._connect.pyarrow.to_awkwardarrow_type(
out.type, options["extensionarray"], mask_node, self
out.type,
options["extensionarray"],
options["record_is_scalar"],
mask_node,
self,
),
out,
)
Expand Down
12 changes: 10 additions & 2 deletions src/awkward/_v2/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1953,7 +1953,11 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):

return pyarrow.Array.from_buffers(
ak._v2._connect.pyarrow.to_awkwardarrow_type(
string_type, options["extensionarray"], mask_node, self
string_type,
options["extensionarray"],
options["record_is_scalar"],
mask_node,
self,
),
length,
[
Expand All @@ -1979,7 +1983,11 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):

return pyarrow.Array.from_buffers(
ak._v2._connect.pyarrow.to_awkwardarrow_type(
list_type, options["extensionarray"], mask_node, self
list_type,
options["extensionarray"],
options["record_is_scalar"],
mask_node,
self,
),
length,
[
Expand Down
6 changes: 5 additions & 1 deletion src/awkward/_v2/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,11 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):

return pyarrow.Array.from_buffers(
ak._v2._connect.pyarrow.to_awkwardarrow_type(
storage_type, options["extensionarray"], mask_node, self
storage_type,
options["extensionarray"],
options["record_is_scalar"],
mask_node,
self,
),
length,
[
Expand Down
6 changes: 5 additions & 1 deletion src/awkward/_v2/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,11 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):

return pyarrow.Array.from_buffers(
ak._v2._connect.pyarrow.to_awkwardarrow_type(
types, options["extensionarray"], mask_node, self
types,
options["extensionarray"],
options["record_is_scalar"],
mask_node,
self,
),
length,
[ak._v2._connect.pyarrow.to_validbits(validbytes)],
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,7 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
ak._v2._connect.pyarrow.to_awkwardarrow_type(
pyarrow.binary(self._size),
options["extensionarray"],
options["record_is_scalar"],
mask_node,
self,
),
Expand All @@ -1125,6 +1126,7 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
ak._v2._connect.pyarrow.to_awkwardarrow_type(
pyarrow.list_(content_type, self._size),
options["extensionarray"],
options["record_is_scalar"],
mask_node,
self,
),
Expand Down
6 changes: 5 additions & 1 deletion src/awkward/_v2/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,11 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):

return pyarrow.Array.from_buffers(
ak._v2._connect.pyarrow.to_awkwardarrow_type(
types, options["extensionarray"], None, self
types,
options["extensionarray"],
options["record_is_scalar"],
None,
self,
),
nptags.shape[0],
[
Expand Down
3 changes: 3 additions & 0 deletions src/awkward/_v2/operations/convert/ak_from_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,7 @@ def _impl(array, generate_bitmasks, highlevel, behavior):
if awkwardarrow_type.mask_type in (None, "IndexedArray"):
out = awkward._v2._connect.pyarrow.remove_optiontype(out)

if awkwardarrow_type.record_is_scalar:
out = out._getitem_at(0)

return ak._v2._util.wrap(out, behavior, highlevel)
2 changes: 2 additions & 0 deletions src/awkward/_v2/operations/convert/ak_from_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ def _load(
return ak._v2.operations.convert.ak_from_buffers._impl(
subform, 0, _DictOfEmptyBuffers(), "", numpy, highlevel, behavior
)
elif len(arrays) == 1 and isinstance(arrays[0], ak._v2.record.Record):
return ak._v2._util.wrap(arrays[0], behavior, highlevel)
else:
return ak._v2.operations.structure.ak_concatenate._impl(
arrays, 0, True, True, highlevel, behavior
Expand Down
8 changes: 7 additions & 1 deletion src/awkward/_v2/operations/convert/ak_to_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,13 @@ def _impl(
count_nulls,
):
layout = ak._v2.operations.convert.to_layout(
array, allow_record=False, allow_other=False
array, allow_record=True, allow_other=False
)
if isinstance(layout, ak._v2.record.Record):
layout = layout.array[layout.at : layout.at + 1]
record_is_scalar = True
else:
record_is_scalar = False

return layout.to_arrow(
list_to32=list_to32,
Expand All @@ -108,4 +113,5 @@ def _impl(
categorical_as_dictionary=categorical_as_dictionary,
extensionarray=extensionarray,
count_nulls=count_nulls,
record_is_scalar=record_is_scalar,
)
14 changes: 12 additions & 2 deletions src/awkward/_v2/operations/convert/ak_to_arrow_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,13 @@ def _impl(
from awkward._v2._connect.pyarrow import pyarrow

layout = ak._v2.operations.convert.to_layout(
array, allow_record=False, allow_other=False
array, allow_record=True, allow_other=False
)
if isinstance(layout, ak._v2.record.Record):
layout = layout.array[layout.at : layout.at + 1]
record_is_scalar = True
else:
record_is_scalar = False

check = [layout]
while check[-1].is_OptionType or check[-1].is_IndexedType:
Expand All @@ -121,6 +126,7 @@ def _impl(
categorical_as_dictionary=categorical_as_dictionary,
extensionarray=extensionarray,
count_nulls=count_nulls,
record_is_scalar=record_is_scalar,
)
)
pafields.append(
Expand All @@ -131,7 +137,10 @@ def _impl(
if check[-1].contents[check[-1].field_to_index(name)].is_OptionType:
optiontype_fields.append(name)

parameters = [{"optiontype_fields": optiontype_fields}]
parameters = [
{"optiontype_fields": optiontype_fields},
{"record_is_scalar": record_is_scalar},
]
for x in check:
parameters.append(
{ak._v2._util.direct_Content_subclass(x).__name__: x._parameters}
Expand All @@ -147,6 +156,7 @@ def _impl(
categorical_as_dictionary=categorical_as_dictionary,
extensionarray=extensionarray,
count_nulls=count_nulls,
record_is_scalar=record_is_scalar,
)
)
pafields.append(
Expand Down
11 changes: 8 additions & 3 deletions src/awkward/_v2/operations/convert/ak_to_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,21 @@ def _impl(array, allow_record, allow_other, numpytype):
elif isinstance(array, ak._v2.record.Record):
if not allow_record:
raise ak._v2._util.error(
TypeError("ak._v2.Record objects are not allowed here")
TypeError("ak._v2.Record objects are not allowed in this function")
)
else:
return array

elif isinstance(array, ak._v2.highlevel.Array):
return array.layout

elif allow_record and isinstance(array, ak._v2.highlevel.Record):
return array.layout
elif isinstance(array, ak._v2.highlevel.Record):
if not allow_record:
raise ak._v2._util.error(
TypeError("ak._v2.Record objects are not allowed in this function")
)
else:
return array.layout

# elif isinstance(array, ak._v2.highlevel.ArrayBuilder):
# return array.snapshot().layout
Expand Down
13 changes: 9 additions & 4 deletions src/awkward/_v2/operations/convert/ak_to_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def to_parquet(
)
fsspec = awkward._v2._connect.pyarrow.import_fsspec("ak.to_parquet")

if isinstance(data, Iterable) and not isinstance(data, Sized):
if isinstance(data, (ak._v2.highlevel.Record, ak._v2.record.Record)):
iterator = iter([data])
elif isinstance(data, Iterable) and not isinstance(data, Sized):
iterator = iter(data)
elif isinstance(data, Iterable):
iterator = iter([data])
Expand All @@ -54,7 +56,7 @@ def to_parquet(
row_group = 0
array = next(iterator)
layout = ak._v2.operations.convert.ak_to_layout.to_layout(
array, allow_record=False, allow_other=False
array, allow_record=True, allow_other=False
)
table = ak._v2.operations.convert.ak_to_arrow_table._impl(
layout,
Expand All @@ -77,7 +79,10 @@ def to_parquet(
else:
column_prefix = ()

form = layout.form
if isinstance(layout, ak._v2.record.Record):
form = layout.array.form
else:
form = layout.form

def parquet_columns(specifier, only=None):
if specifier is None:
Expand Down Expand Up @@ -200,7 +205,7 @@ def parquet_columns(specifier, only=None):
except StopIteration:
break
layout = ak._v2.operations.convert.ak_to_layout.to_layout(
array, allow_record=False, allow_other=False
array, allow_record=True, allow_other=False
)
table = ak._v2.operations.convert.ak_to_arrow_table._impl(
layout,
Expand Down
Loading

0 comments on commit de7cae8

Please sign in to comment.