Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C++ refactoring: _getitem_array implementation #959

Merged
merged 13 commits into from
Jul 20, 2021
Merged
64 changes: 47 additions & 17 deletions src/awkward/_v2/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,53 @@ def _repr(self, indent, pre, post):
out.append(post)
return "".join(out)

def bytemask(self):
    """Expand the packed bit mask into a byte-per-entry buffer.

    Allocates an ``Index8`` with ``8 * len(self._mask)`` slots, fills it by
    running the ``awkward_BitMaskedArray_to_ByteMaskedArray`` kernel with this
    array's own ``valid_when`` and ``lsb_order`` settings, and returns the raw
    buffer truncated to ``self._length`` entries.

    Any error reported by the kernel is raised through ``handle_error``.
    """
    backend = self._mask.nplike
    n_mask_bytes = len(self._mask)
    expanded = ak._v2.index.Index8.empty(n_mask_bytes * 8, backend)

    # Kernels are looked up by name plus the dtypes of their array arguments.
    kernel = backend[
        "awkward_BitMaskedArray_to_ByteMaskedArray",
        expanded.dtype.type,
        self._mask.dtype.type,
    ]
    status = kernel(
        expanded.data,
        self._mask.data,
        n_mask_bytes,
        self._valid_when,
        self._lsb_order,
    )
    self.handle_error(status)

    # The buffer is padded to a multiple of 8; keep only the real entries.
    return expanded.data[: self._length]

def toByteMaskedArray(self):
    """Build a ``ByteMaskedArray`` equivalent of this bit-masked array.

    The packed mask is expanded with the
    ``awkward_BitMaskedArray_to_ByteMaskedArray`` kernel into an ``Index8``
    of ``8 * len(self._mask)`` slots, which is then sliced down to
    ``self._length``. The result reuses this array's content, ``valid_when``,
    identifier and parameters.

    Any error reported by the kernel is raised through ``handle_error``.
    """
    backend = self._mask.nplike
    n_mask_bytes = len(self._mask)
    expanded = ak._v2.index.Index8.empty(n_mask_bytes * 8, backend)

    kernel = backend[
        "awkward_BitMaskedArray_to_ByteMaskedArray",
        expanded.dtype.type,
        self._mask.dtype.type,
    ]
    status = kernel(
        expanded.data,
        self._mask.data,
        n_mask_bytes,
        False,  # this differs from the kernel call in 'bytemask'
        self._lsb_order,
    )
    self.handle_error(status)

    # Slice the Index itself (not its raw buffer) so the constructor
    # receives an Index trimmed to the logical length.
    trimmed = expanded[: self._length]
    return ByteMaskedArray(
        trimmed,
        self._content,
        self._valid_when,
        self._identifier,
        self._parameters,
    )

def _getitem_at(self, where):
if where < 0:
where += len(self)
if 0 > where or where >= len(self):
raise IndexError("array index out of bounds")
raise ak._v2.contents.content.NestedIndexError(self, where)
if self._lsb_order:
bit = bool(self._mask[where // 8] & (1 << (where % 8)))
else:
Expand All @@ -135,22 +177,7 @@ def _getitem_at(self, where):
return None

def _getitem_range(self, where):
# In general, slices must convert BitMaskedArray to ByteMaskedArray.
if self._lsb_order:
bytemask = (
self._mask.nplike.unpackbits(self._mask)
.reshape(-1, 8)[:, ::-1]
.reshape(-1)
.view(np.int8)
)
else:
bytemask = self._mask.nplike.unpackbits(self._mask).view(np.int8)
start, stop, step = where.indices(len(self))
return ByteMaskedArray(
Index(bytemask[start:stop]),
self._content[start:stop],
valid_when=self._valid_when,
)
return self.toByteMaskedArray()._getitem_range(where)

def _getitem_field(self, where):
return BitMaskedArray(
Expand All @@ -169,3 +196,6 @@ def _getitem_fields(self, where):
length=self._length,
lsb_order=self._lsb_order,
)

def _getitem_array(self, where, allow_lazy):
    """Select entries by array by delegating to the byte-masked form.

    Converts this array with ``toByteMaskedArray`` and forwards ``where`` and
    ``allow_lazy`` unchanged to that object's ``_getitem_array``.
    """
    as_bytemasked = self.toByteMaskedArray()
    return as_bytemasked._getitem_array(where, allow_lazy)
13 changes: 10 additions & 3 deletions src/awkward/_v2/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _getitem_at(self, where):
if where < 0:
where += len(self)
if 0 > where or where >= len(self):
raise IndexError("array index out of bounds")
raise ak._v2.contents.content.NestedIndexError(self, where)
if self._mask[where] == self._valid_when:
return self._content[where]
else:
Expand All @@ -96,8 +96,8 @@ def _getitem_at(self, where):
def _getitem_range(self, where):
start, stop, step = where.indices(len(self))
return ByteMaskedArray(
Index(self._mask[start:stop]),
self._content[start:stop],
self._mask[start:stop],
self._content._getitem_range(slice(start, stop)),
valid_when=self._valid_when,
)

Expand All @@ -110,3 +110,10 @@ def _getitem_fields(self, where):
return ByteMaskedArray(
self._mask, self._content[where], valid_when=self._valid_when
)

def _getitem_array(self, where, allow_lazy):
    """Select entries of the mask and the content with the same array.

    ``where`` indexes ``self._mask`` directly and is forwarded, together with
    ``allow_lazy``, to the content's ``_getitem_array``. The two results are
    wrapped in a new ``ByteMaskedArray`` with the same ``valid_when``.
    """
    # NOTE(review): only mask, content and valid_when are passed on; the
    # identifier and parameters of this array are not forwarded here.
    picked_mask = self._mask[where]
    picked_content = self._content._getitem_array(where, allow_lazy)
    return ByteMaskedArray(
        picked_mask,
        picked_content,
        valid_when=self._valid_when,
    )
190 changes: 143 additions & 47 deletions src/awkward/_v2/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import Iterable

import awkward as ak
from awkward._v2.tmp_for_testing import v1_to_v2, v2_to_v1

np = ak.nplike.NumpyMetadata.instance()

Expand Down Expand Up @@ -44,69 +45,164 @@ def parameters(self):
return self._parameters

def handle_error(self, error):
if error.filename is None:
filename = ""
else:
filename = error.filename.decode(errors="surrogateescape")
if error.str is not None:
if error.filename is None:
filename = ""
else:
filename = error.filename.decode(errors="surrogateescape")

message = error.str.decode(errors="surrogateescape")
message = error.str.decode(errors="surrogateescape")

if error.pass_through:
raise ValueError(message + filename)
if error.pass_through:
raise ValueError(message + filename)

else:
if error.id != ak._util.kSliceNone and self._identifier is not None:
# FIXME https://github.com/scikit-hep/awkward-1.0/blob/45d59ef4ae45eebb02995b8e1acaac0d46fb9573/src/libawkward/util.cpp#L443-L450
pass
else:
if error.id != ak._util.kSliceNone and self._identifier is not None:
# FIXME https://github.com/scikit-hep/awkward-1.0/blob/45d59ef4ae45eebb02995b8e1acaac0d46fb9573/src/libawkward/util.cpp#L443-L450
pass

if error.attempt != ak._util.kSliceNone:
message += " (attempting to get {0})".format(error.attempt)
if error.attempt != ak._util.kSliceNone:
message += " (attempting to get {0})".format(error.attempt)

# FIXME: attempt to pretty-print the array at this point; fall back to class name
pretty = " in " + type(self).__name__
# FIXME: attempt to pretty-print the array at this point; fall back to class name
pretty = " in " + type(self).__name__

message += pretty
raise ValueError(message + filename)
message += pretty
raise ValueError(message + filename)

def __getitem__(self, where):
if ak._util.isint(where):
return self._getitem_at(where)
import awkward._v2.contents.numpyarray

try:
if ak._util.isint(where):
return self._getitem_at(where)

elif isinstance(where, slice) and where.step is None:
return self._getitem_range(where)

elif isinstance(where, slice):
raise NotImplementedError("needs _getitem_next")

elif ak._util.isstr(where):
return self._getitem_field(where)

elif where is np.newaxis:
raise NotImplementedError("needs _getitem_next")

elif where is Ellipsis:
raise NotImplementedError("needs _getitem_next")

elif isinstance(where, tuple):
raise NotImplementedError("needs _getitem_next")

elif isinstance(where, ak.highlevel.Array):
raise NotImplementedError("needs _getitem_next")

elif isinstance(where, awkward._v2.contents.numpyarray.NumpyArray):
nplike = where.nplike
if issubclass(where.data.dtype.type, np.integer):
carry = where.data
elif issubclass(where.data.dtype.type, (np.bool_, np.bool)):
(carry,) = nplike.nonzero(where.data)
else:
raise TypeError(
"one-dimensional array slice must be an array of integers or "
"booleans, not\n\n {0}".format(
repr(where.data).replace("\n", "\n ")
)
)
asrange = self._getitem_asrange(carry, nplike)
if asrange is not None:
return asrange
else:
return self._getitem_array(carry, allow_lazy=True)

elif isinstance(where, Content):
raise NotImplementedError("needs _getitem_next")

elif isinstance(where, Iterable) and all(isinstance(x, str) for x in where):
return self._getitem_fields(where)

elif isinstance(where, Iterable):
return self.__getitem__(
v1_to_v2(ak.operations.convert.to_layout(where))
)

else:
raise TypeError(
"only integers, slices (`:`), ellipsis (`...`), np.newaxis (`None`), "
"integer/boolean arrays (possibly with variable-length nested "
"lists or missing values), field name (str) or names (non-tuple "
"iterable of str) are valid indices for slicing, not\n\n "
+ repr(where)
)

except NestedIndexError as err:
raise IndexError(
"""cannot slice

{0}

with

{1}

elif isinstance(where, slice) and where.step is None:
return self._getitem_range(where)
because an index is out of bounds (in {2} of length {3} using sub-slice {4})""".format(
repr(ak.Array(v2_to_v1(self))),
repr(where),
type(err.array).__name__,
len(err.array),
repr(err.slicer),
)
)

elif isinstance(where, slice):
raise NotImplementedError("needs _getitem_next")
def _getitem_asrange(self, where, nplike):
    """Try to satisfy an array slice with a cheap range slice.

    Runs the ``awkward_Index_iscontiguous`` kernel on ``where`` and, if it
    reports true, replaces the array lookup by ``self`` or by a leading range
    of ``self``.
    NOTE(review): the branches below only make sense if the kernel reports
    whether ``where`` is exactly ``0, 1, ..., len(where) - 1``; the original
    author marks the kernel as badly named. Confirm against the kernel.

    Args:
        where: the integer carry array produced by ``__getitem__``.
        nplike: the array library used to allocate the flag and look up the
            kernel.

    Returns:
        ``self`` when ``where`` covers the whole array, a range slice of the
        first ``len(where)`` entries when it is shorter, or ``None`` when the
        kernel reports false and the caller must fall back to
        ``_getitem_array``.

    Raises:
        NestedIndexError: when the kernel reports true but ``where`` is longer
            than ``self``. This is a subclass of ``IndexError``, so existing
            ``except IndexError`` callers keep working, and ``__getitem__``
            can now turn it into its descriptive out-of-bounds message.
    """
    # One-element output buffer that the kernel writes its answer into.
    result = nplike.empty(1, dtype=np.bool_)
    self.handle_error(
        nplike[
            "awkward_Index_iscontiguous",  # badly named
            np.bool_,
            where.dtype.type,
        ](
            result,
            where,
            len(where),
        )
    )

    if not result[0]:
        return None

    if len(where) == len(self):
        return self
    if len(where) < len(self):
        return self._getitem_range(slice(0, len(where)))

    # Previously a bare ``raise IndexError`` with no message; report it the
    # same way as every other out-of-bounds access in the contents classes.
    raise NestedIndexError(self, where)

elif ak._util.isstr(where):
return self._getitem_field(where)

elif where is np.newaxis:
raise NotImplementedError("needs _getitem_next")
class NestedIndexError(IndexError):
def __init__(self, array, slicer):
self._array = array
self._slicer = slicer

elif where is Ellipsis:
raise NotImplementedError("needs _getitem_next")
@property
def array(self):
return self._array

elif isinstance(where, tuple):
raise NotImplementedError("needs _getitem_next")
@property
def slicer(self):
return self._slicer

elif isinstance(where, ak.highlevel.Array):
raise NotImplementedError("needs _getitem_next")
def __str__(self):
return """cannot slice

elif isinstance(where, Content):
raise NotImplementedError("needs _getitem_next")
{0}

elif isinstance(where, Iterable) and all(ak._util.isstr(x) for x in where):
return self._getitem_fields(where)
with

elif isinstance(where, Iterable):
raise NotImplementedError("needs _getitem_next")
{1}

else:
raise TypeError(
"only integers, slices (`:`), ellipsis (`...`), np.newaxis (`None`), "
"integer/boolean arrays (possibly with variable-length nested "
"lists or missing values), field name (str) or names (non-tuple "
"iterable of str) are valid indices for slicing, not\n\n "
+ repr(where)
)
because an index is out of bounds""".format(
repr(self._array),
repr(self._slicer),
)
12 changes: 10 additions & 2 deletions src/awkward/_v2/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import awkward as ak
from awkward._v2.contents.content import Content

np = ak.nplike.NumpyMetadata.instance()


class EmptyArray(Content):
def __init__(self, identifier=None, parameters=None):
Expand All @@ -28,13 +30,19 @@ def __len__(self):
return 0

def _getitem_at(self, where):
raise IndexError("array of type Empty has no index " + repr(where))
raise ak._v2.contents.content.NestedIndexError(self, where)

def _getitem_range(self, where):
return EmptyArray()
return self

def _getitem_field(self, where):
raise IndexError("field " + repr(where) + " not found")

def _getitem_fields(self, where):
raise IndexError("fields " + repr(where) + " not found")

def _getitem_array(self, where, allow_lazy):
if len(where) == 0:
return self
else:
raise ak._v2.contents.content.NestedIndexError(self, where)
5 changes: 4 additions & 1 deletion src/awkward/_v2/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _getitem_at(self, where):
if where < 0:
where += len(self)
if 0 > where or where >= len(self):
raise IndexError("array index out of bounds")
raise ak._v2.contents.content.NestedIndexError(self, where)
return self._content[self._index[where]]

def _getitem_range(self, where):
Expand All @@ -85,3 +85,6 @@ def _getitem_field(self, where):

def _getitem_fields(self, where):
    """Project the content onto ``where`` and keep the same index."""
    projected = self._content[where]
    return IndexedArray(self._index, projected)

def _getitem_array(self, where, allow_lazy):
    """Select by array by indexing only ``self._index``.

    The content is passed through untouched; ``allow_lazy`` is not used.
    """
    picked_index = self._index[where]
    return IndexedArray(picked_index, self._content)
5 changes: 4 additions & 1 deletion src/awkward/_v2/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _getitem_at(self, where):
if where < 0:
where += len(self)
if 0 > where or where >= len(self):
raise IndexError("array index out of bounds")
raise ak._v2.contents.content.NestedIndexError(self, where)
if self._index[where] < 0:
return None
else:
Expand All @@ -89,3 +89,6 @@ def _getitem_field(self, where):

def _getitem_fields(self, where):
    """Project the content onto ``where`` and keep the same index."""
    projected = self._content[where]
    return IndexedOptionArray(self._index, projected)

def _getitem_array(self, where, allow_lazy):
    """Select by array by indexing only ``self._index``.

    The content is passed through untouched; ``allow_lazy`` is not used.
    """
    picked_index = self._index[where]
    return IndexedOptionArray(picked_index, self._content)
Loading