diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py index ec3f0aadb..77f125b12 100644 --- a/docs/examples/oceanmodelling.py +++ b/docs/examples/oceanmodelling.py @@ -10,7 +10,7 @@ from jax import config config.update("jax_enable_x64", True) -from dataclasses import dataclass +from dataclasses import dataclass, field from jax import hessian from jax import config @@ -195,7 +195,6 @@ def dataset_3d(pos, vel): # %% -from dataclasses import field @dataclass