From fb816422b9f1cf97ebd1e698b073eb674c026708 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sun, 9 Apr 2023 20:26:13 -0500 Subject: [PATCH] Custom dataclass decorator (#11) * create custom dataclass decorator * fix tests * add typing extension * ignore overloads from covarage --- README.md | 12 +- poetry.lock | 352 ++++++++++++++++++++----------------- pyproject.toml | 6 +- simple_pytree/__init__.py | 5 +- simple_pytree/dataclass.py | 103 +++++++++++ simple_pytree/pytree.py | 54 ------ tests/test_pytree.py | 29 ++- 7 files changed, 326 insertions(+), 235 deletions(-) create mode 100644 simple_pytree/dataclass.py diff --git a/README.md b/README.md index c84084f..667dd01 100644 --- a/README.md +++ b/README.md @@ -72,24 +72,26 @@ Static fields are not included in the pytree leaves, they are passed as pytree metadata instead. ### Dataclasses -You can seamlessly use the `dataclasses.dataclass` decorator with `Pytree` classes. -Since `static_field` returns instances of `dataclasses.Field` these it will work as expected: +`simple_pytree` provides a `dataclass` decorator you can use with classes +that contain `static_field`s: ```python import jax -from dataclasses import dataclass -from simple_pytree import Pytree, static_field +from simple_pytree import Pytree, dataclass, static_field @dataclass class Foo(Pytree): x: int - y: int = static_field(2) # with default value + y: int = static_field(default=2) foo = Foo(1) foo = jax.tree_map(lambda x: -x, foo) # y is not modified assert foo.x == -1 and foo.y == 2 ``` +`simple_pytree.dataclass` is just a wrapper around `dataclasses.dataclass` but +when used static analysis tools and IDEs will understand that `static_field` is a +field specifier just like `dataclasses.field`. ### Mutability `Pytree` objects are immutable by default after `__init__`: diff --git a/poetry.lock b/poetry.lock index 688e448..148dfe4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -12,25 +12,6 @@ files = [ {file = "absl_py-1.4.0-py3-none-any.whl", hash = "sha256:0d3fe606adfa4f7db64792dd4c7aee4ee0c38ab75dfd353b7a83ed3e957fcb47"}, ] -[[package]] -name = "attrs" -version = "22.2.0" -description = "Classes Without Boilerplate" -category = "dev" -optional = false -python-versions = ">=3.6" -files = [ - {file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"}, - {file = "attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"}, -] - -[package.extras] -cov = ["attrs[tests]", "coverage-enable-subprocess", "coverage[toml] (>=5.3)"] -dev = ["attrs[docs,tests]"] -docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope.interface"] -tests = ["attrs[tests-no-zope]", "zope.interface"] -tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"] - [[package]] name = "black" version = "23.1.0" @@ -107,20 +88,20 @@ files = [ [[package]] name = "chex" -version = "0.1.6" +version = "0.1.7" description = "Chex: Testing made fun, in JAX!" category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "chex-0.1.6-py3-none-any.whl", hash = "sha256:5e04a28fcab3c2c0fb05431c6e646e8e4889cacb3e1670369e556f12029cf785"}, - {file = "chex-0.1.6.tar.gz", hash = "sha256:adb5d2352b5f0d248ccf594be1b1bf9ee7a2bee2a57f0eac78547538d479b0e7"}, + {file = "chex-0.1.7-py3-none-any.whl", hash = "sha256:9f583015303b1205443843c0b55849bb287f1dfdbd22d9907b1ebb04f964d93e"}, + {file = "chex-0.1.7.tar.gz", hash = "sha256:74ed49799ac4d229881456d468136f1b19a9f9839e3de72b058824e2a4f4dedd"}, ] [package.dependencies] absl-py = ">=0.9.0" dm-tree = ">=0.1.5" -jax = ">=0.1.55" +jax = ">=0.4.6" jaxlib = ">=0.1.37" numpy = ">=1.18.0" toolz = ">=0.9.0" @@ -155,63 +136,63 @@ files = [ [[package]] name = "coverage" -version = "7.2.2" +version = "7.2.3" description = "Code coverage measurement for Python" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "coverage-7.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c90e73bdecb7b0d1cea65a08cb41e9d672ac6d7995603d6465ed4914b98b9ad7"}, - {file = "coverage-7.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e2926b8abedf750c2ecf5035c07515770944acf02e1c46ab08f6348d24c5f94d"}, - {file = "coverage-7.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:57b77b9099f172804e695a40ebaa374f79e4fb8b92f3e167f66facbf92e8e7f5"}, - {file = "coverage-7.2.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:efe1c0adad110bf0ad7fb59f833880e489a61e39d699d37249bdf42f80590169"}, - {file = "coverage-7.2.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2199988e0bc8325d941b209f4fd1c6fa007024b1442c5576f1a32ca2e48941e6"}, - {file = "coverage-7.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:81f63e0fb74effd5be736cfe07d710307cc0a3ccb8f4741f7f053c057615a137"}, - {file = "coverage-7.2.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:186e0fc9cf497365036d51d4d2ab76113fb74f729bd25da0975daab2e107fd90"}, - {file = "coverage-7.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:420f94a35e3e00a2b43ad5740f935358e24478354ce41c99407cddd283be00d2"}, - {file = "coverage-7.2.2-cp310-cp310-win32.whl", hash = "sha256:38004671848b5745bb05d4d621526fca30cee164db42a1f185615f39dc997292"}, - {file = "coverage-7.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:0ce383d5f56d0729d2dd40e53fe3afeb8f2237244b0975e1427bfb2cf0d32bab"}, - {file = "coverage-7.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3eb55b7b26389dd4f8ae911ba9bc8c027411163839dea4c8b8be54c4ee9ae10b"}, - {file = "coverage-7.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d2b96123a453a2d7f3995ddb9f28d01fd112319a7a4d5ca99796a7ff43f02af5"}, - {file = "coverage-7.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:299bc75cb2a41e6741b5e470b8c9fb78d931edbd0cd009c58e5c84de57c06731"}, - {file = "coverage-7.2.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5e1df45c23d4230e3d56d04414f9057eba501f78db60d4eeecfcb940501b08fd"}, - {file = "coverage-7.2.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:006ed5582e9cbc8115d2e22d6d2144a0725db542f654d9d4fda86793832f873d"}, - {file = "coverage-7.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d683d230b5774816e7d784d7ed8444f2a40e7a450e5720d58af593cb0b94a212"}, - {file = "coverage-7.2.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8efb48fa743d1c1a65ee8787b5b552681610f06c40a40b7ef94a5b517d885c54"}, - {file = "coverage-7.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4c752d5264053a7cf2fe81c9e14f8a4fb261370a7bb344c2a011836a96fb3f57"}, - {file = "coverage-7.2.2-cp311-cp311-win32.whl", hash = "sha256:55272f33da9a5d7cccd3774aeca7a01e500a614eaea2a77091e9be000ecd401d"}, - {file = "coverage-7.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:92ebc1619650409da324d001b3a36f14f63644c7f0a588e331f3b0f67491f512"}, - {file = "coverage-7.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5afdad4cc4cc199fdf3e18088812edcf8f4c5a3c8e6cb69127513ad4cb7471a9"}, - {file = "coverage-7.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0484d9dd1e6f481b24070c87561c8d7151bdd8b044c93ac99faafd01f695c78e"}, - {file = "coverage-7.2.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d530191aa9c66ab4f190be8ac8cc7cfd8f4f3217da379606f3dd4e3d83feba69"}, - {file = "coverage-7.2.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ac0f522c3b6109c4b764ffec71bf04ebc0523e926ca7cbe6c5ac88f84faced0"}, - {file = "coverage-7.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ba279aae162b20444881fc3ed4e4f934c1cf8620f3dab3b531480cf602c76b7f"}, - {file = "coverage-7.2.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:53d0fd4c17175aded9c633e319360d41a1f3c6e352ba94edcb0fa5167e2bad67"}, - {file = "coverage-7.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c99cb7c26a3039a8a4ee3ca1efdde471e61b4837108847fb7d5be7789ed8fd9"}, - {file = "coverage-7.2.2-cp37-cp37m-win32.whl", hash = "sha256:5cc0783844c84af2522e3a99b9b761a979a3ef10fb87fc4048d1ee174e18a7d8"}, - {file = "coverage-7.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:817295f06eacdc8623dc4df7d8b49cea65925030d4e1e2a7c7218380c0072c25"}, - {file = "coverage-7.2.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6146910231ece63facfc5984234ad1b06a36cecc9fd0c028e59ac7c9b18c38c6"}, - {file = "coverage-7.2.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:387fb46cb8e53ba7304d80aadca5dca84a2fbf6fe3faf6951d8cf2d46485d1e5"}, - {file = "coverage-7.2.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:046936ab032a2810dcaafd39cc4ef6dd295df1a7cbead08fe996d4765fca9fe4"}, - {file = "coverage-7.2.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e627dee428a176ffb13697a2c4318d3f60b2ccdde3acdc9b3f304206ec130ccd"}, - {file = "coverage-7.2.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fa54fb483decc45f94011898727802309a109d89446a3c76387d016057d2c84"}, - {file = "coverage-7.2.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3668291b50b69a0c1ef9f462c7df2c235da3c4073f49543b01e7eb1dee7dd540"}, - {file = "coverage-7.2.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7c20b731211261dc9739bbe080c579a1835b0c2d9b274e5fcd903c3a7821cf88"}, - {file = "coverage-7.2.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5764e1f7471cb8f64b8cda0554f3d4c4085ae4b417bfeab236799863703e5de2"}, - {file = "coverage-7.2.2-cp38-cp38-win32.whl", hash = "sha256:4f01911c010122f49a3e9bdc730eccc66f9b72bd410a3a9d3cb8448bb50d65d3"}, - {file = "coverage-7.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:c448b5c9e3df5448a362208b8d4b9ed85305528313fca1b479f14f9fe0d873b8"}, - {file = "coverage-7.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bfe7085783cda55e53510482fa7b5efc761fad1abe4d653b32710eb548ebdd2d"}, - {file = "coverage-7.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9d22e94e6dc86de981b1b684b342bec5e331401599ce652900ec59db52940005"}, - {file = "coverage-7.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:507e4720791977934bba016101579b8c500fb21c5fa3cd4cf256477331ddd988"}, - {file = "coverage-7.2.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bc4803779f0e4b06a2361f666e76f5c2e3715e8e379889d02251ec911befd149"}, - {file = "coverage-7.2.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db8c2c5ace167fd25ab5dd732714c51d4633f58bac21fb0ff63b0349f62755a8"}, - {file = "coverage-7.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4f68ee32d7c4164f1e2c8797535a6d0a3733355f5861e0f667e37df2d4b07140"}, - {file = "coverage-7.2.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d52f0a114b6a58305b11a5cdecd42b2e7f1ec77eb20e2b33969d702feafdd016"}, - {file = "coverage-7.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:797aad79e7b6182cb49c08cc5d2f7aa7b2128133b0926060d0a8889ac43843be"}, - {file = "coverage-7.2.2-cp39-cp39-win32.whl", hash = "sha256:db45eec1dfccdadb179b0f9ca616872c6f700d23945ecc8f21bb105d74b1c5fc"}, - {file = "coverage-7.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:8dbe2647bf58d2c5a6c5bcc685f23b5f371909a5624e9f5cd51436d6a9f6c6ef"}, - {file = "coverage-7.2.2-pp37.pp38.pp39-none-any.whl", hash = "sha256:872d6ce1f5be73f05bea4df498c140b9e7ee5418bfa2cc8204e7f9b817caa968"}, - {file = "coverage-7.2.2.tar.gz", hash = "sha256:36dd42da34fe94ed98c39887b86db9d06777b1c8f860520e21126a75507024f2"}, + {file = "coverage-7.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e58c0d41d336569d63d1b113bd573db8363bc4146f39444125b7f8060e4e04f5"}, + {file = "coverage-7.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:344e714bd0fe921fc72d97404ebbdbf9127bac0ca1ff66d7b79efc143cf7c0c4"}, + {file = "coverage-7.2.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:974bc90d6f6c1e59ceb1516ab00cf1cdfbb2e555795d49fa9571d611f449bcb2"}, + {file = "coverage-7.2.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0743b0035d4b0e32bc1df5de70fba3059662ace5b9a2a86a9f894cfe66569013"}, + {file = "coverage-7.2.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d0391fb4cfc171ce40437f67eb050a340fdbd0f9f49d6353a387f1b7f9dd4fa"}, + {file = "coverage-7.2.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4a42e1eff0ca9a7cb7dc9ecda41dfc7cbc17cb1d02117214be0561bd1134772b"}, + {file = "coverage-7.2.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:be19931a8dcbe6ab464f3339966856996b12a00f9fe53f346ab3be872d03e257"}, + {file = "coverage-7.2.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:72fcae5bcac3333a4cf3b8f34eec99cea1187acd55af723bcbd559adfdcb5535"}, + {file = "coverage-7.2.3-cp310-cp310-win32.whl", hash = "sha256:aeae2aa38395b18106e552833f2a50c27ea0000122bde421c31d11ed7e6f9c91"}, + {file = "coverage-7.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:83957d349838a636e768251c7e9979e899a569794b44c3728eaebd11d848e58e"}, + {file = "coverage-7.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dfd393094cd82ceb9b40df4c77976015a314b267d498268a076e940fe7be6b79"}, + {file = "coverage-7.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:182eb9ac3f2b4874a1f41b78b87db20b66da6b9cdc32737fbbf4fea0c35b23fc"}, + {file = "coverage-7.2.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bb1e77a9a311346294621be905ea8a2c30d3ad371fc15bb72e98bfcfae532df"}, + {file = "coverage-7.2.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca0f34363e2634deffd390a0fef1aa99168ae9ed2af01af4a1f5865e362f8623"}, + {file = "coverage-7.2.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55416d7385774285b6e2a5feca0af9652f7f444a4fa3d29d8ab052fafef9d00d"}, + {file = "coverage-7.2.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:06ddd9c0249a0546997fdda5a30fbcb40f23926df0a874a60a8a185bc3a87d93"}, + {file = "coverage-7.2.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:fff5aaa6becf2c6a1699ae6a39e2e6fb0672c2d42eca8eb0cafa91cf2e9bd312"}, + {file = "coverage-7.2.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ea53151d87c52e98133eb8ac78f1206498c015849662ca8dc246255265d9c3c4"}, + {file = "coverage-7.2.3-cp311-cp311-win32.whl", hash = "sha256:8f6c930fd70d91ddee53194e93029e3ef2aabe26725aa3c2753df057e296b925"}, + {file = "coverage-7.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:fa546d66639d69aa967bf08156eb8c9d0cd6f6de84be9e8c9819f52ad499c910"}, + {file = "coverage-7.2.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b2317d5ed777bf5a033e83d4f1389fd4ef045763141d8f10eb09a7035cee774c"}, + {file = "coverage-7.2.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be9824c1c874b73b96288c6d3de793bf7f3a597770205068c6163ea1f326e8b9"}, + {file = "coverage-7.2.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c3b2803e730dc2797a017335827e9da6da0e84c745ce0f552e66400abdfb9a1"}, + {file = "coverage-7.2.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f69770f5ca1994cb32c38965e95f57504d3aea96b6c024624fdd5bb1aa494a1"}, + {file = "coverage-7.2.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1127b16220f7bfb3f1049ed4a62d26d81970a723544e8252db0efde853268e21"}, + {file = "coverage-7.2.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:aa784405f0c640940595fa0f14064d8e84aff0b0f762fa18393e2760a2cf5841"}, + {file = "coverage-7.2.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3146b8e16fa60427e03884301bf8209221f5761ac754ee6b267642a2fd354c48"}, + {file = "coverage-7.2.3-cp37-cp37m-win32.whl", hash = "sha256:1fd78b911aea9cec3b7e1e2622c8018d51c0d2bbcf8faaf53c2497eb114911c1"}, + {file = "coverage-7.2.3-cp37-cp37m-win_amd64.whl", hash = "sha256:0f3736a5d34e091b0a611964c6262fd68ca4363df56185902528f0b75dbb9c1f"}, + {file = "coverage-7.2.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:981b4df72c93e3bc04478153df516d385317628bd9c10be699c93c26ddcca8ab"}, + {file = "coverage-7.2.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c0045f8f23a5fb30b2eb3b8a83664d8dc4fb58faddf8155d7109166adb9f2040"}, + {file = "coverage-7.2.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f760073fcf8f3d6933178d67754f4f2d4e924e321f4bb0dcef0424ca0215eba1"}, + {file = "coverage-7.2.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c86bd45d1659b1ae3d0ba1909326b03598affbc9ed71520e0ff8c31a993ad911"}, + {file = "coverage-7.2.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:172db976ae6327ed4728e2507daf8a4de73c7cc89796483e0a9198fd2e47b462"}, + {file = "coverage-7.2.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d2a3a6146fe9319926e1d477842ca2a63fe99af5ae690b1f5c11e6af074a6b5c"}, + {file = "coverage-7.2.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f649dd53833b495c3ebd04d6eec58479454a1784987af8afb77540d6c1767abd"}, + {file = "coverage-7.2.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7c4ed4e9f3b123aa403ab424430b426a1992e6f4c8fd3cb56ea520446e04d152"}, + {file = "coverage-7.2.3-cp38-cp38-win32.whl", hash = "sha256:eb0edc3ce9760d2f21637766c3aa04822030e7451981ce569a1b3456b7053f22"}, + {file = "coverage-7.2.3-cp38-cp38-win_amd64.whl", hash = "sha256:63cdeaac4ae85a179a8d6bc09b77b564c096250d759eed343a89d91bce8b6367"}, + {file = "coverage-7.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:20d1a2a76bb4eb00e4d36b9699f9b7aba93271c9c29220ad4c6a9581a0320235"}, + {file = "coverage-7.2.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ea748802cc0de4de92ef8244dd84ffd793bd2e7be784cd8394d557a3c751e21"}, + {file = "coverage-7.2.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21b154aba06df42e4b96fc915512ab39595105f6c483991287021ed95776d934"}, + {file = "coverage-7.2.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd214917cabdd6f673a29d708574e9fbdb892cb77eb426d0eae3490d95ca7859"}, + {file = "coverage-7.2.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c2e58e45fe53fab81f85474e5d4d226eeab0f27b45aa062856c89389da2f0d9"}, + {file = "coverage-7.2.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:87ecc7c9a1a9f912e306997ffee020297ccb5ea388421fe62a2a02747e4d5539"}, + {file = "coverage-7.2.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:387065e420aed3c71b61af7e82c7b6bc1c592f7e3c7a66e9f78dd178699da4fe"}, + {file = "coverage-7.2.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ea3f5bc91d7d457da7d48c7a732beaf79d0c8131df3ab278e6bba6297e23c6c4"}, + {file = "coverage-7.2.3-cp39-cp39-win32.whl", hash = "sha256:ae7863a1d8db6a014b6f2ff9c1582ab1aad55a6d25bac19710a8df68921b6e30"}, + {file = "coverage-7.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:3f04becd4fcda03c0160d0da9c8f0c246bc78f2f7af0feea1ec0930e7c93fa4a"}, + {file = "coverage-7.2.3-pp37.pp38.pp39-none-any.whl", hash = "sha256:965ee3e782c7892befc25575fa171b521d33798132692df428a09efacaffe8d0"}, + {file = "coverage-7.2.3.tar.gz", hash = "sha256:d298c2815fa4891edd9abe5ad6e6cb4207104c7dd9fd13aea3fdebf6f9b91259"}, ] [package.dependencies] @@ -283,20 +264,20 @@ files = [ [[package]] name = "etils" -version = "1.1.0" +version = "1.2.0" description = "Collection of common python utils" category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "etils-1.1.0-py3-none-any.whl", hash = "sha256:b994873778574aa803ffe6e84aaa2a3880f44f776a2bd542a28c62bf4ffa0db5"}, - {file = "etils-1.1.0.tar.gz", hash = "sha256:7a2a4950779a281ef4c7e595ae215990b15c1d8c6f8b0a27850092af5ad2c641"}, + {file = "etils-1.2.0-py3-none-any.whl", hash = "sha256:c6585069b387fdbeed6a2c571b8bcf312ecdb577c95065461e5fad9ed1973989"}, + {file = "etils-1.2.0.tar.gz", hash = "sha256:29d369e2dcf43960d9ee338330579d04badd606c88f015f4e1a38d3adbe446d8"}, ] [package.extras] all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"] array-types = ["etils[enp]"] -dev = ["chex", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"] +dev = ["chex", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"] eapp = ["absl-py", "etils[epy]", "simple_parsing"] ecolab = ["etils[enp]", "etils[epy]", "jupyter", "mediapy", "numpy"] edc = ["etils[epy]", "typing_extensions"] @@ -327,30 +308,30 @@ test = ["pytest (>=6)"] [[package]] name = "filelock" -version = "3.10.0" +version = "3.11.0" description = "A platform independent file lock." category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "filelock-3.10.0-py3-none-any.whl", hash = "sha256:e90b34656470756edf8b19656785c5fea73afa1953f3e1b0d645cef11cab3182"}, - {file = "filelock-3.10.0.tar.gz", hash = "sha256:3199fd0d3faea8b911be52b663dfccceb84c95949dd13179aa21436d1a79c4ce"}, + {file = "filelock-3.11.0-py3-none-any.whl", hash = "sha256:f08a52314748335c6460fc8fe40cd5638b85001225db78c2aa01c8c0db83b318"}, + {file = "filelock-3.11.0.tar.gz", hash = "sha256:3618c0da67adcc0506b015fd11ef7faf1b493f0b40d87728e19986b536890c37"}, ] [package.extras] -docs = ["furo (>=2022.12.7)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.2.1)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-timeout (>=2.1)"] +docs = ["furo (>=2023.3.27)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.2.2)", "diff-cover (>=7.5)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] [[package]] name = "flax" -version = "0.6.7" +version = "0.6.8" description = "Flax: A neural network library for JAX designed for flexibility" category = "dev" optional = false python-versions = "*" files = [ - {file = "flax-0.6.7-py3-none-any.whl", hash = "sha256:b4f11553197d3d47385363be24a860ea685ddd968e0c1ab6f6a83c9ba4e93b01"}, - {file = "flax-0.6.7.tar.gz", hash = "sha256:985d668e0d774d2c79fd93f5217793ea538a98861ef1dc7396c0879b0020fbf3"}, + {file = "flax-0.6.8-py3-none-any.whl", hash = "sha256:221225804c263e39fe3cc8f754dc4192597cb0f063926b2338ea6563604747ed"}, + {file = "flax-0.6.8.tar.gz", hash = "sha256:bf1f81dd5dfbb10c603490531a86b1174ebbc38e5c5e8116a98115c135194c10"}, ] [package.dependencies] @@ -366,18 +347,18 @@ typing-extensions = ">=4.1.1" [package.extras] all = ["matplotlib"] -testing = ["atari-py (==0.2.5)", "clu", "gym (==0.18.3)", "jaxlib", "jraph (>=0.0.6dev0)", "ml-collections", "mypy", "nbstripout", "opencv-python", "pytest", "pytest-cov", "pytest-custom-exit-code", "pytest-xdist (==1.34.0)", "pytype", "sentencepiece", "tensorflow", "tensorflow-datasets", "tensorflow-text (>=2.11.0)", "torch"] +testing = ["atari-py (==0.2.5)", "clu", "einops", "gym (==0.18.3)", "jaxlib", "jraph (>=0.0.6dev0)", "ml-collections", "mypy", "nbstripout", "opencv-python", "pytest", "pytest-cov", "pytest-custom-exit-code", "pytest-xdist (==1.34.0)", "pytype", "sentencepiece", "tensorflow", "tensorflow-datasets", "tensorflow-text (>=2.11.0)", "torch"] [[package]] name = "identify" -version = "2.5.21" +version = "2.5.22" description = "File identification library for Python" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "identify-2.5.21-py2.py3-none-any.whl", hash = "sha256:69edcaffa8e91ae0f77d397af60f148b6b45a8044b2cc6d99cafa5b04793ff00"}, - {file = "identify-2.5.21.tar.gz", hash = "sha256:7671a05ef9cfaf8ff63b15d45a91a1147a03aaccb2976d4e9bd047cbbc508471"}, + {file = "identify-2.5.22-py2.py3-none-any.whl", hash = "sha256:f0faad595a4687053669c112004178149f6c326db71ee999ae4636685753ad2f"}, + {file = "identify-2.5.22.tar.gz", hash = "sha256:f7a93d6cf98e29bd07663c60728e7a4057615068d7a639d132dc883b2d54d31e"}, ] [package.extras] @@ -434,55 +415,61 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] [[package]] name = "jax" -version = "0.4.6" +version = "0.4.8" description = "Differentiate, compile, and transform Numpy code." category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "jax-0.4.6.tar.gz", hash = "sha256:d06ea8fba4ed315ec55110396058cb48c8edb2ab0b412f28c8a123beee9e58ab"}, + {file = "jax-0.4.8.tar.gz", hash = "sha256:08116481f7336db16c24812bfb5e6f9786915f4c2f6ff4028331fa69e7535202"}, ] [package.dependencies] -numpy = ">=1.20" +ml_dtypes = ">=0.0.3" +numpy = ">=1.21" opt_einsum = "*" -scipy = ">=1.5" +scipy = ">=1.7" [package.extras] australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.4)"] -cpu = ["jaxlib (==0.4.6)"] -cuda = ["jaxlib (==0.4.6+cuda11.cudnn86)"] -cuda11-cudnn82 = ["jaxlib (==0.4.6+cuda11.cudnn82)"] -cuda11-cudnn86 = ["jaxlib (==0.4.6+cuda11.cudnn86)"] -minimum-jaxlib = ["jaxlib (==0.4.4)"] -tpu = ["jaxlib (==0.4.6)", "libtpu-nightly (==0.1.dev20230309)", "requests"] +ci = ["jaxlib (==0.4.7)"] +cpu = ["jaxlib (==0.4.7)"] +cuda = ["jaxlib (==0.4.7+cuda11.cudnn86)"] +cuda11-cudnn82 = ["jaxlib (==0.4.7+cuda11.cudnn82)"] +cuda11-cudnn86 = ["jaxlib (==0.4.7+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.7+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.7+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.6)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.7+cuda12.cudnn88)"] +cuda12-pip = ["jaxlib (==0.4.7+cuda12.cudnn88)", "nvidia-cublas-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.7)"] +tpu = ["jaxlib (==0.4.7)", "libtpu-nightly (==0.1.dev20230327)", "requests"] [[package]] name = "jaxlib" -version = "0.4.6" +version = "0.4.7" description = "XLA library for JAX" category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "jaxlib-0.4.6-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:c42ea44671f1e560d63f742c787a65744f1efd0a20bbfe177e9d3e8bd7cece92"}, - {file = "jaxlib-0.4.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1393968b0f808c1769195990f1ea138903bc4012bdffb850eecd10e113b8fca8"}, - {file = "jaxlib-0.4.6-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:e18e3d5fd5d1aee94bd97791c7157ea7fd682f5eb8a04a8b1a3b0ed011175892"}, - {file = "jaxlib-0.4.6-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:f7233463e1f79b330d3a1e12629e1bbc334acf6f5be22a0af244a2ed544afdfe"}, - {file = "jaxlib-0.4.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed2f504f0d48a1e727322aa5baae0ec0405fc5fcb5e4135bb15740978535b5e0"}, - {file = "jaxlib-0.4.6-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:34b0b4e41e185fba36b81cf68d4979503afba4640bf29b7f6709b17b3c3c55bc"}, - {file = "jaxlib-0.4.6-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:2949b6b6b77f296982b42ad2c6350526baf47b0f105118a65a9e9b2093de6572"}, - {file = "jaxlib-0.4.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3be2b70104f9547a3281e5c06dfafe1be27c4927d6b62b69a55d26977cd03e15"}, - {file = "jaxlib-0.4.6-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:554fc7300c61d76b5996145bd33a8dd19d60c49e57ad38686057d719b1d69d38"}, - {file = "jaxlib-0.4.6-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:cdfbff50bae46065d2fac6250260077e2c554df52252c45cb5ca949bed378b6f"}, - {file = "jaxlib-0.4.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:399d83e16a35d66693b27951b87be566d91e47c7f4ac1fc5a362536a7b9c29cc"}, - {file = "jaxlib-0.4.6-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:4079c42247db33c69f10710b6a5a570ed89d773e0e549612fadba3d06fe4773c"}, + {file = "jaxlib-0.4.7-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:63c2890978e8646516db3d8a680b43d2bed8b63543a70556391f589a261bd85f"}, + {file = "jaxlib-0.4.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0c16f922507277d5630e81d9c1a4974366a27aad5230d645d063bc2011564d01"}, + {file = "jaxlib-0.4.7-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:da88382e6487805974cea6facc61ba92b5828a7a1f2dd80f762c487d873a2b47"}, + {file = "jaxlib-0.4.7-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:022b216036c009989d4c0683538820c19247215bb99fdd35c7bf32838d596be6"}, + {file = "jaxlib-0.4.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d0f1d3b6ef6c68013898cca958ab1507d6809b523275037efbdb9aaaaab158ba"}, + {file = "jaxlib-0.4.7-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:0ae7178c33460822d9d8d03718cba395e02e6bac2402709c35826c94f0c9cc7b"}, + {file = "jaxlib-0.4.7-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:ea07605e37d2b4e25f3c639e0d22ab4605fbc1a10ea918fd14ce09077bdaffb6"}, + {file = "jaxlib-0.4.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:48b85d3c8923b1619ddf8cbf14c4e4daf6919796d8aa9d006ce2a085e8202930"}, + {file = "jaxlib-0.4.7-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:a860f2990c97bee5ffcdbb5111751591e5e7a66d5e32b4f6d9e6aa14ac82bf27"}, + {file = "jaxlib-0.4.7-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:c78dc2b6fa1c92ead137a23d1bd3e10d04c58b268e77eca811502abac05b2b19"}, + {file = "jaxlib-0.4.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f1f3726e374d0d6fcc14da540b71b758d37356c6726f0f4b48e2f5530a5f8769"}, + {file = "jaxlib-0.4.7-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:d4629205dbe342153941db5f69c4a1bfe35fd8d2947aebe34f4dff3771d3fff7"}, ] [package.dependencies] -numpy = ">=1.20" -scipy = ">=1.5" +ml-dtypes = ">=0.0.3" +numpy = ">=1.21" +scipy = ">=1.7" [[package]] name = "markdown-it-py" @@ -521,6 +508,43 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "ml-dtypes" +version = "0.0.4" +description = "" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "ml_dtypes-0.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a03c5acc55a878fac190d428ef01438f930cbef3fb8625c8c8fd2e3adc277607"}, + {file = "ml_dtypes-0.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e600aa70a9f8ee85c9488eb14852124c878ec824c3c7996d2d82010655eabfe"}, + {file = "ml_dtypes-0.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74c1fb29d2e586f643fb1a70b1dffe9fc35bc3ad8c76ec0797b2bf9f7ac128b"}, + {file = "ml_dtypes-0.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:b3f49901eb42cac259156edc17d4c1922ac47ddd1fe3c05169f445135a07319c"}, + {file = "ml_dtypes-0.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:52aaa9318e2a4ec65a6bc4842df3442a9cfa00a9b8365a08e0370b0dfefc3a5a"}, + {file = "ml_dtypes-0.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db9912d50466d386a4016b16f889722183f6d6c03d9e478fdf62f41e50de0059"}, + {file = "ml_dtypes-0.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ece1269b5311489e26b3f3181d498b8829042f380cd160d7fe02f2393f69a71"}, + {file = "ml_dtypes-0.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:68d2e6c83c762aa6d476ea715ce6b2ac67f519c242cfe93d7a49cb76a83f6650"}, + {file = "ml_dtypes-0.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:85085f9dac85b1eee5f7d2044c47bb3df72abc7785d38d176744fde5782b76ce"}, + {file = "ml_dtypes-0.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a75ef23de72daf5efcc99799dfaa387386b79502a123909b0d3098ef84ffa6fa"}, + {file = "ml_dtypes-0.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b651fa1f91ce83cf037db202cd2601ac9b649016ec8593459c0295e613bf47"}, + {file = "ml_dtypes-0.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:b28c6b7831fa2cbb3169ed3053f10fb11d0415e2f250b893eb874e3af747a1f3"}, + {file = "ml_dtypes-0.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:23ff15cd9ba61cc42287097c30ae6841facd6dc14cc252f977d6430b8cd6eccc"}, + {file = "ml_dtypes-0.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5b148131da64f85053b79380cf34471eb869f7c027e2198a0c86d5e6fc9531f"}, + {file = "ml_dtypes-0.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebc64866c1848999fab6f4a2938e769aed95b964085ebdcd7cd45e350192e457"}, + {file = "ml_dtypes-0.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:e64869be11c830736c40513c47918c421a8385243846f1e8fd838793d866aa87"}, + {file = "ml_dtypes-0.0.4.tar.gz", hash = "sha256:45623c738d477d7a0f3f8e4c94998dc49025202c520e62e27f0ef688db2f696f"}, +] + +[package.dependencies] +numpy = [ + {version = ">1.20", markers = "python_version <= \"3.9\""}, + {version = ">=1.23.3", markers = "python_version > \"3.10\""}, + {version = ">=1.21.2", markers = "python_version > \"3.9\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + [[package]] name = "msgpack" version = "1.0.5" @@ -606,6 +630,18 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "nest-asyncio" +version = "1.5.6" +description = "Patch asyncio to allow nested event loops" +category = "dev" +optional = false +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.5.6-py3-none-any.whl", hash = "sha256:b9a953fb40dceaa587d109609098db21900182b16440652454a146cffb06e8b8"}, + {file = "nest_asyncio-1.5.6.tar.gz", hash = "sha256:d267cc1ff794403f7df692964d1d2a3fa9418ffea2a3f6859a439ff482fef290"}, +] + [[package]] name = "nodeenv" version = "1.7.0" @@ -700,14 +736,14 @@ typing-extensions = ">=3.10.0" [[package]] name = "orbax" -version = "0.1.4" +version = "0.1.7" description = "Orbax" category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "orbax-0.1.4-py3-none-any.whl", hash = "sha256:fe074bb1c0dfa302e03613c18febdd81191438d21cf393e962e3042a49819b08"}, - {file = "orbax-0.1.4.tar.gz", hash = "sha256:b2077e8814519393a4179715b6420b36f93f1bccff5fafc8e5a5307d890d4dd0"}, + {file = "orbax-0.1.7-py3-none-any.whl", hash = "sha256:67c7ce52b5476202af84977e8db03dede6c009b5d1f1095acfc175578038449b"}, + {file = "orbax-0.1.7.tar.gz", hash = "sha256:2517f566134db6597d2850450b7f486efd24bf24962bc4881007f4cc8e978b37"}, ] [package.dependencies] @@ -718,6 +754,7 @@ importlib_resources = "*" jax = ">=0.4.6" jaxlib = "*" msgpack = "*" +nest_asyncio = "*" numpy = "*" pyyaml = "*" tensorstore = ">=0.1.20" @@ -752,19 +789,19 @@ files = [ [[package]] name = "platformdirs" -version = "3.1.1" +version = "3.2.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "platformdirs-3.1.1-py3-none-any.whl", hash = "sha256:e5986afb596e4bb5bde29a79ac9061aa955b94fca2399b7aaac4090860920dd8"}, - {file = "platformdirs-3.1.1.tar.gz", hash = "sha256:024996549ee88ec1a9aa99ff7f8fc819bb59e2c3477b410d90a16d32d6e707aa"}, + {file = "platformdirs-3.2.0-py3-none-any.whl", hash = "sha256:ebe11c0d7a805086e99506aa331612429a72ca7cd52a1f0d277dc4adc20cb10e"}, + {file = "platformdirs-3.2.0.tar.gz", hash = "sha256:d5b638ca397f25f979350ff789db335903d7ea010ab28903f57b27e1b16c2b08"}, ] [package.extras] docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] [[package]] name = "pluggy" @@ -784,14 +821,14 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pre-commit" -version = "3.1.1" +version = "3.2.2" description = "A framework for managing and maintaining multi-language pre-commit hooks." category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "pre_commit-3.1.1-py2.py3-none-any.whl", hash = "sha256:b80254e60668e1dd1f5c03a1c9e0413941d61f568a57d745add265945f65bfe8"}, - {file = "pre_commit-3.1.1.tar.gz", hash = "sha256:d63e6537f9252d99f65755ae5b79c989b462d511ebbc481b561db6a297e1e865"}, + {file = "pre_commit-3.2.2-py2.py3-none-any.whl", hash = "sha256:0b4210aea813fe81144e87c5a291f09ea66f199f367fa1df41b55e1d26e1e2b4"}, + {file = "pre_commit-3.2.2.tar.gz", hash = "sha256:5b808fcbda4afbccf6d6633a56663fed35b6c2bc08096fd3d47ce197ac351d9d"}, ] [package.dependencies] @@ -818,18 +855,17 @@ plugins = ["importlib-metadata"] [[package]] name = "pytest" -version = "7.2.2" +version = "7.3.0" description = "pytest: simple powerful testing with Python" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "pytest-7.2.2-py3-none-any.whl", hash = "sha256:130328f552dcfac0b1cec75c12e3f005619dc5f874f0a06e8ff7263f0ee6225e"}, - {file = "pytest-7.2.2.tar.gz", hash = "sha256:c99ab0c73aceb050f68929bc93af19ab6db0558791c6a0715723abe9d0ade9d4"}, + {file = "pytest-7.3.0-py3-none-any.whl", hash = "sha256:933051fa1bfbd38a21e73c3960cebdad4cf59483ddba7696c48509727e17f201"}, + {file = "pytest-7.3.0.tar.gz", hash = "sha256:58ecc27ebf0ea643ebfdf7fb1249335da761a00c9f955bcd922349bcb68ee57d"}, ] [package.dependencies] -attrs = ">=19.2.0" colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" @@ -838,7 +874,7 @@ pluggy = ">=0.12,<2.0" tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] [[package]] name = "pytest-cov" @@ -911,14 +947,14 @@ files = [ [[package]] name = "rich" -version = "13.3.2" +version = "13.3.3" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" category = "dev" optional = false python-versions = ">=3.7.0" files = [ - {file = "rich-13.3.2-py3-none-any.whl", hash = "sha256:a104f37270bf677148d8acb07d33be1569eeee87e2d1beb286a4e9113caf6f2f"}, - {file = "rich-13.3.2.tar.gz", hash = "sha256:91954fe80cfb7985727a467ca98a7618e5dd15178cc2da10f553b36a93859001"}, + {file = "rich-13.3.3-py3-none-any.whl", hash = "sha256:540c7d6d26a1178e8e8b37e9ba44573a3cd1464ff6348b99ee7061b95d1c6333"}, + {file = "rich-13.3.3.tar.gz", hash = "sha256:dc84400a9d842b3a9c5ff74addd8eb798d155f36c1c91303888e0a66850d2a15"}, ] [package.dependencies] @@ -970,14 +1006,14 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo [[package]] name = "setuptools" -version = "67.6.0" +version = "67.6.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "setuptools-67.6.0-py3-none-any.whl", hash = "sha256:b78aaa36f6b90a074c1fa651168723acbf45d14cb1196b6f02c0fd07f17623b2"}, - {file = "setuptools-67.6.0.tar.gz", hash = "sha256:2ee892cd5f29f3373097f5a814697e397cf3ce313616df0af11231e2ad118077"}, + {file = "setuptools-67.6.1-py3-none-any.whl", hash = "sha256:e728ca814a823bf7bf60162daf9db95b93d532948c4c0bea762ce62f60189078"}, + {file = "setuptools-67.6.1.tar.gz", hash = "sha256:257de92a9d50a60b8e22abfcbb771571fde0dbf3ec234463212027a4eeecbe9a"}, ] [package.extras] @@ -987,29 +1023,29 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( [[package]] name = "tensorstore" -version = "0.1.33" +version = "0.1.35" description = "Read and write large, multi-dimensional arrays" category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "tensorstore-0.1.33-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:af284933437d0c292b609b63190759090d50a0b4d802d8c309ce7647377b9810"}, - {file = "tensorstore-0.1.33-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:abba0438892eda42a729184fa812b67f91980556af8f8ed1d819a654cc97e341"}, - {file = "tensorstore-0.1.33-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b07680fd0de4ea7787567ae2ee2a9e829262286105f758b6213db5592269b178"}, - {file = "tensorstore-0.1.33-cp310-cp310-win_amd64.whl", hash = "sha256:2629b39c2555f94e5e32776459e6c879ad3ceded5ed2b9963914f30b90ae3af7"}, - {file = "tensorstore-0.1.33-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:c7645ec63234bbed51a5f310cfd90579d229de9d2ab089573ff5af48316145c0"}, - {file = "tensorstore-0.1.33-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:eddb82fcb91844373cd414004cfbee878930cd9662ee784f0f5df8e08248b178"}, - {file = "tensorstore-0.1.33-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2767fc729f2ea3734f3b1ef8814d57b69639f3b2648d8ab5361b985a158caf6a"}, - {file = "tensorstore-0.1.33-cp311-cp311-win_amd64.whl", hash = "sha256:4b90ea9ed8ca733954a93497a3d166aa917565a4116f52bfc815d982ec210fdf"}, - {file = "tensorstore-0.1.33-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:b54974fee556208b4a35aab12aba979714674e54f9e625c59f223e861355a1e2"}, - {file = "tensorstore-0.1.33-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c0086f36a0e00869fb88deb662b08786ccc4954ca1c6a02b2f532309acbe10e0"}, - {file = "tensorstore-0.1.33-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7046b345c545fcd05c0dbe620bf4f2cd3c06d4194d16b482404ea5fb267e21d7"}, - {file = "tensorstore-0.1.33-cp38-cp38-win_amd64.whl", hash = "sha256:c9b6744372ab31a4c99c405685fa9ea9c64357a3cca3fbdc7d71280c10d8c9bc"}, - {file = "tensorstore-0.1.33-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:8ff7e2d1a5ee6e08e9780ab2e92c283a969e8dd06fd34eb1e893de4ad00e5f52"}, - {file = "tensorstore-0.1.33-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1ae9efafc9892435b34bea3e5cfced965f5889d36cfbd00769b7bf399d3602a7"}, - {file = "tensorstore-0.1.33-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5247a5eccebd6d1780b7ba54bf1f7c0f408acc7683bcf84cd006010649540b67"}, - {file = "tensorstore-0.1.33-cp39-cp39-win_amd64.whl", hash = "sha256:c3f7181712f61efa62149b37f34fcc51b2243fbc92b47d4b3353222f5b8620b7"}, - {file = "tensorstore-0.1.33.tar.gz", hash = "sha256:471d449b9dcbfe5f08691ccf95f3b61188d8306540dd6112d9a689f8c3b6f6e8"}, + {file = "tensorstore-0.1.35-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:4e1b4210b777c4a585a183bdd435bfa8aa5628c46075cb64adcd9b4bdd124e35"}, + {file = "tensorstore-0.1.35-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:285cfc4816bb0cc305cddc11f25f81a14bd84af0c8bbd39e42c81413c0bf242e"}, + {file = "tensorstore-0.1.35-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51a8578518362daee85e39162bf870faf5545868cd53fe99ede0eda1fad288f6"}, + {file = "tensorstore-0.1.35-cp310-cp310-win_amd64.whl", hash = "sha256:a0318ea4afd4f2c00ce2dd4b540acb31c45a260bda94ae7e4340a1a4d28c6848"}, + {file = "tensorstore-0.1.35-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:7173f451c1b970230f57b0cdefdd692940ee457d4982a696d00aba163a7fee9a"}, + {file = "tensorstore-0.1.35-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d2fc9d8e5aaa54c538434592eaf88f7dfa6773fb35a960cc4cbe20bef55092d7"}, + {file = "tensorstore-0.1.35-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f34fc72a9ceff1c1666b6c85cc86858adbd10ff1c3c6b98c7788bbea9161a6a6"}, + {file = "tensorstore-0.1.35-cp311-cp311-win_amd64.whl", hash = "sha256:216bf4c00ec4aabf699d2a54ee9311f2fb19a2a3a904d7abb2194572af2f8384"}, + {file = "tensorstore-0.1.35-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:932f96d35ebdf0e4650bd9cd089319e9c5723d2aaf3f65123a821fc3b04ca4ac"}, + {file = "tensorstore-0.1.35-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c9d0ea7888a88e5d892894ddf14f76d37665d0adbc2d861c559d0d6d5eaac20d"}, + {file = "tensorstore-0.1.35-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd1406fd325517331887b0b2533e27e17e549196418eb087730bafa6fffaa236"}, + {file = "tensorstore-0.1.35-cp38-cp38-win_amd64.whl", hash = "sha256:99ad4e577249c2dfb07a501d78a31e29d6b9e53752384d58782ad53e6014ac41"}, + {file = "tensorstore-0.1.35-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:9acf9e9a7b3117881ec11f26930d0fee89cce6bb3d81056c15317f7cf2c0c1e1"}, + {file = "tensorstore-0.1.35-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d742468c6aec6a1dcf3ec164694c9827eb498b2c701b8020fb4a56446f9bbc1a"}, + {file = "tensorstore-0.1.35-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f8011a556b4cbbc547e9365cf833ce7a5fb9fdbaa3f86a92665b78a78b78131"}, + {file = "tensorstore-0.1.35-cp39-cp39-win_amd64.whl", hash = "sha256:04c383af4a17e238fffdda7abc06b66c8583554d523cd721f2011bbb7a715327"}, + {file = "tensorstore-0.1.35.tar.gz", hash = "sha256:93db16e2f448cad716628640d3b73b87d9b259ae8ba1741a82108aef14e427c6"}, ] [package.dependencies] @@ -1043,7 +1079,7 @@ files = [ name = "typing-extensions" version = "4.5.0" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1091,4 +1127,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "513129af3b9cdea678c061e404fa7047dc00815950be820214d2ab4d4b532e76" +content-hash = "d58bbea0b379f12a18195c9ac4b4db16e916ed188687f74198393ccc22ea20c5" diff --git a/pyproject.toml b/pyproject.toml index 1237f1f..0cad5d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,12 +5,13 @@ description = "" authors = ["Cristian Garcia "] license = "MIT" readme = "README.md" -packages = [{include = "simple_pytree"}] +packages = [{ include = "simple_pytree" }] [tool.poetry.dependencies] python = ">=3.8,<3.12" jax = "*" jaxlib = "*" +typing-extensions = "*" [tool.poetry.group.dev.dependencies] @@ -24,3 +25,6 @@ flax = "*" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.coverage.report] +exclude_lines = ["@tp.overload"] diff --git a/simple_pytree/__init__.py b/simple_pytree/__init__.py index 351dcf4..bd8e792 100644 --- a/simple_pytree/__init__.py +++ b/simple_pytree/__init__.py @@ -1,5 +1,6 @@ __version__ = "0.1.7" -from .pytree import Pytree, field, static_field +from .dataclass import dataclass, field, static_field +from .pytree import Pytree, PytreeMeta -__all__ = ["Pytree", "field", "static_field"] +__all__ = ["Pytree", "PytreeMeta", "dataclass", "field", "static_field"] diff --git a/simple_pytree/dataclass.py b/simple_pytree/dataclass.py new file mode 100644 index 0000000..044768e --- /dev/null +++ b/simple_pytree/dataclass.py @@ -0,0 +1,103 @@ +import dataclasses +import typing as tp + +import typing_extensions as tpe + +A = tp.TypeVar("A") + + +def field( + *, + default: tp.Any = dataclasses.MISSING, + pytree_node: bool = True, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +): + if metadata is None: + metadata = {} + else: + metadata = dict(metadata) + + if "pytree_node" in metadata: + raise ValueError("'pytree_node' found in metadata") + + metadata["pytree_node"] = pytree_node + + return dataclasses.field( # type: ignore + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +def static_field( + *, + default: tp.Any = dataclasses.MISSING, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +): + return field( + default=default, + pytree_node=False, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +@tp.overload +def dataclass(cls: tp.Type[A]) -> tp.Type[A]: + ... + + +@tp.overload +def dataclass( + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, +) -> tp.Callable[[tp.Type[A]], tp.Type[A]]: + ... + + +@tpe.dataclass_transform(field_specifiers=(field, static_field, dataclasses.field)) +def dataclass( + cls: tp.Optional[tp.Type[A]] = None, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, +) -> tp.Union[tp.Type[A], tp.Callable[[tp.Type[A]], tp.Type[A]]]: + decorator = dataclasses.dataclass( + init=init, + repr=repr, + eq=eq, + order=order, + unsafe_hash=unsafe_hash, + frozen=frozen, + ) + + if cls is None: + return decorator + + return decorator(cls) diff --git a/simple_pytree/pytree.py b/simple_pytree/pytree.py index 75cbf7c..5858e62 100644 --- a/simple_pytree/pytree.py +++ b/simple_pytree/pytree.py @@ -11,60 +11,6 @@ P = tp.TypeVar("P", bound="Pytree") -def field( - default: tp.Any = dataclasses.MISSING, - *, - pytree_node: bool = True, - default_factory: tp.Any = dataclasses.MISSING, - init: bool = True, - repr: bool = True, - hash: tp.Optional[bool] = None, - compare: bool = True, - metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, -): - if metadata is None: - metadata = {} - else: - metadata = dict(metadata) - - if "pytree_node" in metadata: - raise ValueError("'pytree_node' found in metadata") - - metadata["pytree_node"] = pytree_node - - return dataclasses.field( - default=default, - default_factory=default_factory, - init=init, - repr=repr, - hash=hash, - compare=compare, - metadata=metadata, - ) - - -def static_field( - default: tp.Any = dataclasses.MISSING, - *, - default_factory: tp.Any = dataclasses.MISSING, - init: bool = True, - repr: bool = True, - hash: tp.Optional[bool] = None, - compare: bool = True, - metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, -): - return field( - default=default, - pytree_node=False, - default_factory=default_factory, - init=init, - repr=repr, - hash=hash, - compare=compare, - metadata=metadata, - ) - - class PytreeMeta(ABCMeta): def __call__(self: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P: obj: P = self.__new__(self, *args, **kwargs) diff --git a/tests/test_pytree.py b/tests/test_pytree.py index 36be36b..7d09247 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -1,11 +1,10 @@ -import dataclasses from typing import Generic, TypeVar import jax import pytest from flax import serialization -from simple_pytree import Pytree, field, static_field +from simple_pytree import Pytree, dataclass, field, static_field class TestPytree: @@ -36,10 +35,10 @@ def __init__(self, y) -> None: pytree.x = 4 def test_immutable_pytree_dataclass(self): - @dataclasses.dataclass(frozen=True) + @dataclass(frozen=True) class Foo(Pytree): y: int = field() - x: int = static_field(2) + x: int = static_field(default=2) pytree = Foo(y=3) @@ -58,7 +57,7 @@ class Foo(Pytree): pytree.x = 4 def test_jit(self): - @dataclasses.dataclass + @dataclass class Foo(Pytree): a: int b: int = static_field() @@ -79,7 +78,7 @@ def __init__(self, a, b): self.a = a self.b = b - @dataclasses.dataclass + @dataclass class Foo(Pytree): bar: Bar c: int @@ -125,15 +124,15 @@ def __init__(self, x: T): MyClass[int] def test_key_paths(self): - @dataclasses.dataclass + @dataclass class Bar(Pytree): a: int = 1 - b: int = static_field(2) + b: int = static_field(default=2) - @dataclasses.dataclass + @dataclass class Foo(Pytree): x: int = 3 - y: int = static_field(4) + y: int = static_field(default=4) z: Bar = field(default_factory=Bar) foo = Foo() @@ -171,12 +170,12 @@ class Foo(Pytree): Foo().replace(y=1) def test_dataclass_inheritance(self): - @dataclasses.dataclass + @dataclass class A(Pytree): a: int = 1 - b: int = static_field(2) + b: int = static_field(default=2) - @dataclasses.dataclass + @dataclass class B(A): c: int = 3 @@ -224,10 +223,10 @@ def __init__(self, y) -> None: assert pytree.x == 4 def test_pytree_dataclass(self): - @dataclasses.dataclass + @dataclass class Foo(Pytree, mutable=True): y: int = field() - x: int = static_field(2) + x: int = static_field(default=2) pytree: Foo = Foo(y=3)