Skip to content

Commit

Permalink
refactor: ak._util (#1848)
Browse files Browse the repository at this point in the history
* refactor: simplify `isint`, remove `isnum`

* refactor: rename `isint` to `is_integer`

* refactor: remove `isstr`

* chore: correct pylint complaints
  • Loading branch information
agoose77 authored Oct 31, 2022
1 parent fc25886 commit fec5eee
Show file tree
Hide file tree
Showing 43 changed files with 135 additions and 143 deletions.
8 changes: 4 additions & 4 deletions src/awkward/_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def prepare_advanced_indexing(items):
list, # of strings
ak.contents.ListOffsetArray,
ak.contents.IndexedOptionArray,
str,
),
)
or ak._util.isstr(item)
or item is np.newaxis
or item is Ellipsis
):
Expand Down Expand Up @@ -126,13 +126,13 @@ def prepare_advanced_indexing(items):


def normalise_item(item, nplike):
if ak._util.isint(item):
if ak._util.is_integer(item):
return int(item)

elif isinstance(item, slice):
return item

elif ak._util.isstr(item):
elif isinstance(item, str):
return item

elif item is np.newaxis:
Expand Down Expand Up @@ -160,7 +160,7 @@ def normalise_item(item, nplike):
elif ak._util.is_sized_iterable(item) and len(item) == 0:
return nplike.empty(0, dtype=np.int64)

elif ak._util.is_sized_iterable(item) and all(ak._util.isstr(x) for x in item):
elif ak._util.is_sized_iterable(item) and all(isinstance(x, str) for x in item):
return list(item)

elif ak._util.is_sized_iterable(item):
Expand Down
18 changes: 11 additions & 7 deletions src/awkward/_typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def shape(self):

@shape.setter
def shape(self, value):
if ak._util.isint(value):
if ak._util.is_integer(value):
value = (value,)
elif value is None or isinstance(value, (UnknownLengthType, UnknownScalar)):
value = (UnknownLength,)
Expand Down Expand Up @@ -347,7 +347,7 @@ def __getitem__(self, where):
missing = max(0, len(self._shape) - (len(before) + len(after)))
where = before + (slice(None, None, None),) * missing + after

if ak._util.isint(where):
if ak._util.is_integer(where):
if len(self._shape) == 1:
if where == 0:
return UnknownScalar(self._dtype)
Expand Down Expand Up @@ -391,7 +391,7 @@ def __getitem__(self, where):
shapes = []
for j in range(num_basic, len(where)):
wh = where[j]
if ak._util.isint(wh):
if ak._util.is_integer(wh):
shapes.append(numpy.array(0))
elif hasattr(wh, "dtype") and hasattr(wh, "shape"):
sh = [
Expand All @@ -416,7 +416,7 @@ def __getitem__(self, where):
elif (
isinstance(where, tuple)
and len(where) > 0
and (ak._util.isint(where[0]) or isinstance(where[0], slice))
and (ak._util.is_integer(where[0]) or isinstance(where[0], slice))
):
head, tail = where[0], where[1:]
next = self.__getitem__(head)
Expand Down Expand Up @@ -466,8 +466,8 @@ def reshape(self, *args):
args = args[0]

assert len(args) != 0
assert ak._util.isint(args[0]) or isinstance(args[0], UnknownLengthType)
assert all(ak._util.isint(x) for x in args[1:])
assert ak._util.is_integer(args[0]) or isinstance(args[0], UnknownLengthType)
assert all(ak._util.is_integer(x) for x in args[1:])
assert all(x >= 0 for x in args[1:])

return TypeTracerArray(self._dtype, (UnknownLength,) + args[1:])
Expand Down Expand Up @@ -587,7 +587,11 @@ def arange(self, *args, **kwargs):
elif len(args) == 3:
start, stop, step = args[0], args[1], args[2]

if ak._util.isint(start) and ak._util.isint(stop) and ak._util.isint(step):
if (
ak._util.is_integer(start)
and ak._util.is_integer(stop)
and ak._util.is_integer(step)
):
length = max(0, (stop - start + (step - (1 if step > 0 else -1))) // step)

return TypeTracerArray(kwargs["dtype"], (length,))
Expand Down
60 changes: 24 additions & 36 deletions src/awkward/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,8 @@ def is_sized_iterable(obj):
return isinstance(obj, Iterable) and isinstance(obj, Sized)


def isint(x):
return isinstance(x, (int, numbers.Integral, np.integer)) and not isinstance(
x, (bool, np.bool_)
)


def isnum(x):
return isinstance(x, (int, float, numbers.Real, np.number)) and not isinstance(
x, (bool, np.bool_)
)


def isstr(x):
return isinstance(x, str)
def is_integer(x):
return isinstance(x, numbers.Integral) and not isinstance(x, bool)


def tobytes(array):
Expand Down Expand Up @@ -153,12 +141,12 @@ def overlay_behavior(behavior: dict | None) -> collections.abc.Mapping:
def arrayclass(layout, behavior):
behavior = overlay_behavior(behavior)
arr = layout.parameter("__array__")
if isstr(arr):
if isinstance(arr, str):
cls = behavior.get(arr)
if isinstance(cls, type) and issubclass(cls, ak.highlevel.Array):
return cls
deeprec = layout.purelist_parameter("__record__")
if isstr(deeprec):
if isinstance(deeprec, str):
cls = behavior.get(("*", deeprec))
if isinstance(cls, type) and issubclass(cls, ak.highlevel.Array):
return cls
Expand All @@ -181,11 +169,11 @@ def custom_cast(obj, behavior):
def custom_broadcast(layout, behavior):
behavior = overlay_behavior(behavior)
custom = layout.parameter("__array__")
if not isstr(custom):
if not isinstance(custom, str):
custom = layout.parameter("__record__")
if not isstr(custom):
if not isinstance(custom, str):
custom = layout.purelist_parameter("__record__")
if isstr(custom):
if isinstance(custom, str):
for key, fcn in behavior.items():
if (
isinstance(key, tuple)
Expand All @@ -202,9 +190,9 @@ def custom_ufunc(ufunc, layout, behavior):

behavior = overlay_behavior(behavior)
custom = layout.parameter("__array__")
if not isstr(custom):
if not isinstance(custom, str):
custom = layout.parameter("__record__")
if isstr(custom):
if isinstance(custom, str):
for key, fcn in behavior.items():
if (
isinstance(key, tuple)
Expand All @@ -219,12 +207,12 @@ def custom_ufunc(ufunc, layout, behavior):
def numba_array_typer(layouttype, behavior):
behavior = overlay_behavior(behavior)
arr = layouttype.parameters.get("__array__")
if isstr(arr):
if isinstance(arr, str):
typer = behavior.get(("__numba_typer__", arr))
if callable(typer):
return typer
deeprec = layouttype.parameters.get("__record__")
if isstr(deeprec):
if isinstance(deeprec, str):
typer = behavior.get(("__numba_typer__", "*", deeprec))
if callable(typer):
return typer
Expand All @@ -234,12 +222,12 @@ def numba_array_typer(layouttype, behavior):
def numba_array_lower(layouttype, behavior):
behavior = overlay_behavior(behavior)
arr = layouttype.parameters.get("__array__")
if isstr(arr):
if isinstance(arr, str):
lower = behavior.get(("__numba_lower__", arr))
if callable(lower):
return lower
deeprec = layouttype.parameters.get("__record__")
if isstr(deeprec):
if isinstance(deeprec, str):
lower = behavior.get(("__numba_lower__", "*", deeprec))
if callable(lower):
return lower
Expand All @@ -249,7 +237,7 @@ def numba_array_lower(layouttype, behavior):
def recordclass(layout, behavior):
behavior = overlay_behavior(behavior)
rec = layout.parameter("__record__")
if isstr(rec):
if isinstance(rec, str):
cls = behavior.get(rec)
if isinstance(cls, type) and issubclass(cls, ak.highlevel.Record):
return cls
Expand All @@ -259,7 +247,7 @@ def recordclass(layout, behavior):
def reducer_recordclass(reducer, layout, behavior):
behavior = overlay_behavior(behavior)
rec = layout.parameter("__record__")
if isstr(rec):
if isinstance(rec, str):
return behavior.get((reducer.highlevel_function(), rec))


Expand All @@ -271,8 +259,8 @@ def typestrs(behavior):
isinstance(key, tuple)
and len(key) == 2
and key[0] == "__typestr__"
and isstr(key[1])
and isstr(typestr)
and isinstance(key[1], str)
and isinstance(typestr, str)
):
out[key[1]] = typestr
return out
Expand All @@ -296,7 +284,7 @@ def gettypestr(parameters, typestrs):
def numba_record_typer(layouttype, behavior):
behavior = overlay_behavior(behavior)
rec = layouttype.parameters.get("__record__")
if isstr(rec):
if isinstance(rec, str):
typer = behavior.get(("__numba_typer__", rec))
if callable(typer):
return typer
Expand All @@ -306,7 +294,7 @@ def numba_record_typer(layouttype, behavior):
def numba_record_lower(layouttype, behavior):
behavior = overlay_behavior(behavior)
rec = layouttype.parameters.get("__record__")
if isstr(rec):
if isinstance(rec, str):
lower = behavior.get(("__numba_lower__", rec))
if callable(lower):
return lower
Expand Down Expand Up @@ -335,7 +323,7 @@ def overload(behavior, signature):
def numba_attrs(layouttype, behavior):
behavior = overlay_behavior(behavior)
rec = layouttype.parameters.get("__record__")
if isstr(rec):
if isinstance(rec, str):
for key, typer in behavior.items():
if (
isinstance(key, tuple)
Expand All @@ -350,7 +338,7 @@ def numba_attrs(layouttype, behavior):
def numba_methods(layouttype, behavior):
behavior = overlay_behavior(behavior)
rec = layouttype.parameters.get("__record__")
if isstr(rec):
if isinstance(rec, str):
for key, typer in behavior.items():
if (
isinstance(key, tuple)
Expand All @@ -369,7 +357,7 @@ def numba_unaryops(unaryop, left, behavior):

if isinstance(left, ak._connect.numba.layout.ContentType):
left = left.parameters.get("__record__")
if not isstr(left):
if not isinstance(left, str):
done = True

if not done:
Expand All @@ -391,12 +379,12 @@ def numba_binops(binop, left, right, behavior):

if isinstance(left, ak._connect.numba.layout.ContentType):
left = left.parameters.get("__record__")
if not isstr(left):
if not isinstance(left, str):
done = True

if isinstance(right, ak._connect.numba.layout.ContentType):
right = right.parameters.get("__record__")
if not isstr(right):
if not isinstance(right, str):
done = True

if not done:
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
)
)
if not isinstance(length, ak._typetracer.UnknownLengthType):
if not (ak._util.isint(length) and length >= 0):
if not (ak._util.is_integer(length) and length >= 0):
raise ak._errors.wrap_error(
TypeError(
"{} 'length' must be a non-negative integer, not {}".format(
Expand Down Expand Up @@ -388,7 +388,7 @@ def _getitem_next(self, head, tail, advanced):
):
return self.toByteMaskedArray()._getitem_next(head, tail, advanced)

elif ak._util.isstr(head):
elif isinstance(head, str):
return self._getitem_next_field(head, tail, advanced)

elif isinstance(head, list):
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def _getitem_next(self, head, tail, advanced):
)
return out2.simplify_optiontype()

elif ak._util.isstr(head):
elif isinstance(head, str):
return self._getitem_next_field(head, tail, advanced)

elif isinstance(head, list):
Expand Down Expand Up @@ -541,7 +541,7 @@ def num(self, axis, depth=0):
posaxis = self.axis_wrap_if_negative(axis)
if posaxis == depth:
out = self.length
if ak._util.isint(out):
if ak._util.is_integer(out):
return np.int64(out)
else:
return out
Expand Down
16 changes: 8 additions & 8 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def form_with_key(self, form_key="node{id}", id_start=0):
def getkey(layout):
return None

elif ak._util.isstr(form_key):
elif isinstance(form_key, str):

def getkey(layout):
out = form_key.format(id=hold_id[0])
Expand Down Expand Up @@ -170,7 +170,7 @@ def to_buffers(
TypeError("cannot call 'to_buffers' on an array without concrete data")
)

if ak._util.isstr(buffer_key):
if isinstance(buffer_key, str):

def getkey(layout, form, attribute):
return buffer_key.format(form_key=form.form_key, attribute=attribute)
Expand Down Expand Up @@ -318,7 +318,7 @@ def _getitem_next_field(self, head, tail, advanced: ak.index.Index | None):
def _getitem_next_fields(self, head, tail, advanced: ak.index.Index | None):
only_fields, not_fields = [], []
for x in tail:
if ak._util.isstr(x) or isinstance(x, list):
if isinstance(x, (str, list)):
only_fields.append(x)
else:
not_fields.append(x)
Expand Down Expand Up @@ -527,7 +527,7 @@ def __getitem__(self, where):
return self._getitem(where)

def _getitem(self, where):
if ak._util.isint(where):
if ak._util.is_integer(where):
return self._getitem_at(where)

elif isinstance(where, slice) and where.step is None:
Expand All @@ -536,7 +536,7 @@ def _getitem(self, where):
elif isinstance(where, slice):
return self._getitem((where,))

elif ak._util.isstr(where):
elif isinstance(where, str):
return self._getitem_field(where)

elif where is np.newaxis:
Expand Down Expand Up @@ -626,7 +626,7 @@ def _getitem(self, where):
return self._carry(ak.index.Index64.empty(0, self._nplike), allow_lazy=True)

elif ak._util.is_sized_iterable(where) and all(
ak._util.isstr(x) for x in where
isinstance(x, str) for x in where
):
return self._getitem_fields(where)

Expand Down Expand Up @@ -1627,8 +1627,8 @@ def to_json(
isinstance(complex_record_fields, Sized)
and isinstance(complex_record_fields, Iterable)
and len(complex_record_fields) == 2
and ak._util.isstr(complex_record_fields[0])
and ak._util.isstr(complex_record_fields[1])
and isinstance(complex_record_fields[0], str)
and isinstance(complex_record_fields[1], str)
):
complex_real_string, complex_imag_string = complex_record_fields
else:
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _getitem_next(self, head, tail, advanced):
elif isinstance(head, slice):
raise ak._errors.index_error(self, head, "array is empty")

elif ak._util.isstr(head):
elif isinstance(head, str):
return self._getitem_next_field(head, tail, advanced)

elif isinstance(head, list):
Expand Down Expand Up @@ -163,7 +163,7 @@ def num(self, axis, depth=0):

if posaxis == depth:
out = self.length
if ak._util.isint(out):
if ak._util.is_integer(out):
return np.int64(out)
else:
return out
Expand Down
Loading

0 comments on commit fec5eee

Please sign in to comment.