Skip to content

Commit

Permalink
Merge pull request #46 from rhayes777/feature/array_impl_dict
Browse files Browse the repository at this point in the history
handle converting JAX ArrayImpl to dict
  • Loading branch information
Jammy2211 authored Feb 1, 2024
2 parents 833821c + 74c7b01 commit c4a3fcd
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
14 changes: 13 additions & 1 deletion autoconf/dictable.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ def nd_array_from_dict(nd_array_dict: dict, **_) -> np.ndarray:
return np.array(nd_array_dict["array"], dtype=getattr(np, nd_array_dict["dtype"]))


def is_array(obj) -> bool:
"""
True if the object is a numpy array or an ArrayImpl (i.e. from JAX)
"""
if isinstance(obj, np.ndarray):
return True
try:
return obj.__class__.__name__ == "ArrayImpl"
except AttributeError:
return False


def to_dict(obj):
if isinstance(obj, (int, float, str, bool, type(None))):
return obj
Expand All @@ -45,7 +57,7 @@ def to_dict(obj):
except TypeError as e:
logger.debug(e)

if isinstance(obj, np.ndarray):
if is_array(obj):
try:
return nd_array_as_dict(obj)
except Exception as e:
Expand Down
19 changes: 19 additions & 0 deletions test_autoconf/test_dictable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,25 @@ def make_array():
return np.array([1.0])


class ArrayImpl:
def __init__(self, array):
self.array = array

@property
def dtype(self):
return self.array.dtype

def tolist(self):
return self.array.tolist()

def __array__(self):
return self.array


def test_array_impl(array):
assert to_dict(ArrayImpl(array)) == to_dict(array)


def test_array_as_dict(array_dict, array):
assert to_dict(array) == array_dict

Expand Down

0 comments on commit c4a3fcd

Please sign in to comment.