From b41d62a71e164951a72d5173d7568ae4890e9887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrique=20Silv=C3=A9rio?= <29920212+HGSilveri@users.noreply.github.com> Date: Mon, 27 Nov 2023 13:54:46 +0100 Subject: [PATCH] JSON serialization support for numpy integer types (#617) * Ignore __venv__ * Fix serialization compatiblity for numpy types * Update pre-commit config * Ignore .mypy_cache --- .flake8 | 2 +- .gitignore | 2 ++ .pre-commit-config.yaml | 4 ++-- .../pulser/json/abstract_repr/serializer.py | 4 +++- pulser-core/pulser/json/coders.py | 4 +++- tests/test_abstract_repr.py | 20 +++++++++++++++++++ tests/test_json.py | 6 ++++++ 7 files changed, 37 insertions(+), 5 deletions(-) diff --git a/.flake8 b/.flake8 index dffd549d..e5b0a13b 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] docstring-convention = google -exclude = ./build, ./docs +exclude = ./build, ./docs, ./__venv__ extend-ignore = # D105 Missing docstring in magic method D105, diff --git a/.gitignore b/.gitignore index 3225aac8..63d2ad3b 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ .vscode .python-version .pytest_cache/ +.mypy_cache/ .idea/ .coverage .spyproject/ @@ -15,3 +16,4 @@ docs/build/ dist/ env* *.egg-info/ +__venv__/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0afee19c..6e61c31a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: 22.1.0 + rev: 23.10.1 hooks: - id: black-jupyter @@ -10,7 +10,7 @@ repos: - id: flake8 - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: isort (python) diff --git a/pulser-core/pulser/json/abstract_repr/serializer.py b/pulser-core/pulser/json/abstract_repr/serializer.py index 07038dc0..0bc2e8ac 100644 --- a/pulser-core/pulser/json/abstract_repr/serializer.py +++ b/pulser-core/pulser/json/abstract_repr/serializer.py @@ -38,12 +38,14 @@ class AbstractReprEncoder(json.JSONEncoder): """The custom encoder for abstract representation of Pulser objects.""" - def default(self, o: Any) -> Union[dict[str, Any], list[Any]]: + def default(self, o: Any) -> dict[str, Any] | list | int: """Handles JSON encoding of objects not supported by default.""" if hasattr(o, "_to_abstract_repr"): return cast(dict, o._to_abstract_repr()) elif isinstance(o, np.ndarray): return cast(list, o.tolist()) + elif isinstance(o, np.integer): + return int(o) elif isinstance(o, set): return list(o) else: diff --git a/pulser-core/pulser/json/coders.py b/pulser-core/pulser/json/coders.py index 989996f8..8dfa6a96 100644 --- a/pulser-core/pulser/json/coders.py +++ b/pulser-core/pulser/json/coders.py @@ -30,7 +30,7 @@ class PulserEncoder(JSONEncoder): """The custom encoder for Pulser objects.""" - def default(self, o: Any) -> dict[str, Any]: + def default(self, o: Any) -> dict[str, Any] | int: """Handles JSON encoding of objects not supported by default.""" if hasattr(o, "_to_dict"): return cast(dict, o._to_dict()) @@ -38,6 +38,8 @@ def default(self, o: Any) -> dict[str, Any]: return obj_to_dict(o, _build=False, _name=o.__name__) elif isinstance(o, np.ndarray): return obj_to_dict(o, o.tolist(), _name="array") + elif isinstance(o, np.integer): + return int(o) elif isinstance(o, set): return obj_to_dict(o, list(o)) else: diff --git a/tests/test_abstract_repr.py b/tests/test_abstract_repr.py index d85d85dd..9dab8064 100644 --- a/tests/test_abstract_repr.py +++ b/tests/test_abstract_repr.py @@ -939,6 +939,26 @@ def test_multi_qubit_target(self): "rhs": 2, } + def test_numpy_types(self): + assert ( + json.loads( + json.dumps(np.array([12345])[0], cls=AbstractReprEncoder) + ) + == 12345 + ) + assert ( + json.loads( + json.dumps(np.array([np.pi])[0], cls=AbstractReprEncoder) + ) + == np.pi + ) + assert ( + json.loads( + json.dumps(np.array(["abc"])[0], cls=AbstractReprEncoder) + ) + == "abc" + ) + def _get_serialized_seq( operations: list[dict] = [], diff --git a/tests/test_json.py b/tests/test_json.py index 661021fb..566b274d 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -243,6 +243,12 @@ def test_type_error(): Sequence._deserialize(json.loads(s)) +def test_numpy_types(): + assert encode_decode(np.array([12])[0]) == 12 + assert encode_decode(np.array([np.pi])[0]) == np.pi + assert encode_decode(np.array(["abc"])[0]) == "abc" + + def test_deprecated_device_args(): seq = Sequence(Register.square(1), MockDevice)