From 102e499d61471c169277f62eda813bce04e7df65 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 7 Mar 2024 17:46:04 +0100 Subject: [PATCH] Fixes #188. --- jaxtyping/_decorator.py | 6 +++++- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index 7a92a9a..e4ee19c 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -189,7 +189,11 @@ def f(...): ... """ global _tb_flag - if _tb_flag and importlib.util.find_spec("jax._src.traceback_util") is not None: + if ( + _tb_flag + and importlib.util.find_spec("jax") is not None + and importlib.util.find_spec("jax._src.traceback_util") is not None + ): import jax._src.traceback_util as traceback_util traceback_util.register_exclusion(__file__) diff --git a/pyproject.toml b/pyproject.toml index 19d30c4..e50b329 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "jaxtyping" -version = "0.2.27" +version = "0.2.28" description = "Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees." readme = "README.md" requires-python ="~=3.9"