Skip to content

Commit

Permalink
add helper function make_serializable
Browse files Browse the repository at this point in the history
  • Loading branch information
wenh06 committed Aug 31, 2023
1 parent 2dc486a commit e54b9cb
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 7 deletions.
2 changes: 1 addition & 1 deletion test/test_models/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_wbce():

criterion_wbce = WeightedBCELoss(torch.ones((1, 2)), reduce=False).to(DEVICE)
criterion_wbce(torch.sigmoid(inp), targ_1)
criterion_wbce = WeightedBCELoss(torch.ones((1, 2)), size_average=False)
criterion_wbce = WeightedBCELoss(torch.ones((1, 2)), size_average=False).to(DEVICE)
criterion_wbce(torch.sigmoid(inp), targ_1)

with pytest.raises(
Expand Down
14 changes: 14 additions & 0 deletions test/test_utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
get_kwargs,
get_required_args,
add_kwargs,
make_serializable,
)
from torch_ecg.cfg import DEFAULTS, _DATA_CACHE
from torch_ecg.utils.download import http_get
Expand Down Expand Up @@ -588,3 +589,16 @@ def func(self, a, b=1):

assert new_func(2) == new_func(2, xxx="a", zzz=100) == 3
assert get_kwargs(new_func) == {"b": 1, "xxx": "yyy", "zzz": None}


def test_make_serializable():
x = np.array([1, 2, 3])
assert make_serializable(x) == [1, 2, 3]
x = {"a": np.array([1, 2, 3]), "b": [np.array([4, 5, 6]), np.array([7, 8, 9])]}
assert make_serializable(x) == {"a": [1, 2, 3], "b": [[4, 5, 6], [7, 8, 9]]}
x = [np.array([1, 2, 3]), np.array([4, 5, 6])]
assert make_serializable(x) == [[1, 2, 3], [4, 5, 6]]
x = (np.array([1, 2, 3]), np.array([4, 5, 6]).mean())
obj = make_serializable(x)
assert obj == [[1, 2, 3], 5.0]
assert isinstance(obj[1], float) and isinstance(x[1], np.float64)
2 changes: 2 additions & 0 deletions torch_ecg/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@
get_kwargs,
get_required_args,
add_kwargs,
make_serializable,
)
from .utils_data import (
get_mask,
Expand Down Expand Up @@ -262,6 +263,7 @@
"get_kwargs",
"get_required_args",
"add_kwargs",
"make_serializable",
"get_mask",
"class_weight_to_sample_weight",
"ensure_lead_fmt",
Expand Down
60 changes: 58 additions & 2 deletions torch_ecg/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"get_kwargs",
"get_required_args",
"add_kwargs",
"make_serializable",
]


Expand Down Expand Up @@ -218,8 +219,8 @@ def get_record_list_recursive3(
def dict_to_str(
d: Union[dict, list, tuple], current_depth: int = 1, indent_spaces: int = 4
) -> str:
"""
Convert a (possibly) nested dict into a `str` of json-like formatted form.
"""Convert a (possibly) nested dict into a `str` of json-like formatted form.
This nested dict might also contain lists or tuples of dict (and of str, int, etc.)
Parameters
Expand Down Expand Up @@ -1540,3 +1541,58 @@ def wrapper(*args: Any, **kwargs_: Any) -> Any:
return func(*args, **filtered_kwargs)

return wrapper


def make_serializable(
x: Union[np.ndarray, np.generic, dict, list, tuple]
) -> Union[list, dict, Number]:
"""Make an object serializable.
This function is used to convert all numpy arrays to list in an object,
and also convert numpy data types to python data types in the object,
so that it can be serialized by :mod:`json`.
Parameters
----------
x : Union[numpy.ndarray, numpy.generic, dict, list, tuple]
Input data, which can be numpy array (or numpy data type),
or dict, list, tuple containing numpy arrays (or numpy data type).
Returns
-------
Union[list, dict, numbers.Number]
Converted data.
Examples
--------
>>> import numpy as np
>>> from fl_sim.utils.misc import make_serializable
>>> x = np.array([1, 2, 3])
>>> make_serializable(x)
[1, 2, 3]
>>> x = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}
>>> make_serializable(x)
{'a': [1, 2, 3], 'b': [4, 5, 6]}
>>> x = [np.array([1, 2, 3]), np.array([4, 5, 6])]
>>> make_serializable(x)
[[1, 2, 3], [4, 5, 6]]
>>> x = (np.array([1, 2, 3]), np.array([4, 5, 6]).mean())
>>> obj = make_serializable(x)
>>> obj
[[1, 2, 3], 5.0]
>>> type(obj[1]), type(x[1])
(float, numpy.float64)
"""
if isinstance(x, np.ndarray):
return x.tolist()
elif isinstance(x, (list, tuple)):
# to avoid cases where the list contains numpy data types
return [make_serializable(v) for v in x]
elif isinstance(x, dict):
for k, v in x.items():
x[k] = make_serializable(v)
elif isinstance(x, np.generic):
return x.item()
# the other types will be returned directly
return x
6 changes: 2 additions & 4 deletions torch_ecg/utils/utils_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch import Tensor, nn

from ..cfg import CFG, DEFAULTS
from .misc import add_docstring
from .misc import add_docstring, make_serializable
from .utils_data import cls_to_bin


Expand Down Expand Up @@ -853,9 +853,7 @@ def compute_receptive_field(
receptive_field = min(receptive_field, input_len)
if fs is not None:
receptive_field /= fs
if isinstance(receptive_field, np.generic):
receptive_field = receptive_field.item()
return receptive_field
return make_serializable(receptive_field)


def default_collate_fn(
Expand Down

0 comments on commit e54b9cb

Please sign in to comment.