Skip to content

Commit

Permalink
Now using wadler_lindig pprint library.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jan 8, 2025
1 parent 9f5a3fa commit c39abe7
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
run-tests:
strategy:
matrix:
python-version: [ 3.9, 3.11 ]
python-version: [ "3.10", "3.12" ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
pip install jaxtyping
```

Requires Python 3.9+.
Requires Python 3.10+.

JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.

Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jaxtyping is a library providing type annotations **and runtime type-checking**
pip install jaxtyping
```

Requires Python 3.9+.
Requires Python 3.10+.

JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.

Expand Down
6 changes: 2 additions & 4 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import importlib.util
import typing
import warnings
from typing import Union
from typing import TypeAlias, Union

from ._array_types import (
AbstractArray as AbstractArray,
Expand All @@ -43,8 +43,6 @@


if typing.TYPE_CHECKING:
import typing_extensions

from jax import Array as Array
from jax.tree_util import PyTreeDef as PyTreeDef
from jax.typing import ArrayLike as ArrayLike, DTypeLike as DTypeLike
Expand Down Expand Up @@ -90,7 +88,7 @@
)

# Set up to deliberately confuse a static type checker.
PyTree: typing_extensions.TypeAlias = getattr(typing, "foo" + "bar")
PyTree: TypeAlias = getattr(typing, "foo" + "bar")
# What's going on with this madness?
#
# At static-type-checking-time, we want `PyTree` to be a type for which both
Expand Down
33 changes: 20 additions & 13 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,11 @@
get_type_hints,
NoReturn,
overload,
ParamSpec,
TypeVar,
Union,
)


try:
from typing import ParamSpec
except ImportError:
# Python < 3.10
from typing_extensions import ParamSpec


from jaxtyping import AbstractArray

from ._config import config
Expand Down Expand Up @@ -847,12 +840,12 @@ def _pformat(x, short_self: bool):
# No performance concerns from delayed imports -- this is only used when we're about
# to raise an error anyway.
try:
# TODO(kidger): this is pretty ugly. We have a circular dependency
# equinox->jaxtyping->equinox. We could consider moving all the pretty-printing
# code from equinox into jaxtyping maybe? Or into some shared dependency?
# If we can, use `eqx.tree_pformat`, which wraps `wadler_lindig.pformat` with
# understanding of a few other JAX-specific things.
import equinox as eqx

pformat = eqx.tree_pformat

if short_self:
try:
self = x["self"]
Expand All @@ -862,9 +855,23 @@ def _pformat(x, short_self: bool):
is_self = lambda y: y is self
pformat = ft.partial(pformat, truncate_leaf=is_self)
except Exception:
import pprint
# Failing that fall back to `wadler_lindig.pformat` directly.
import wadler_lindig

pformat = wadler_lindig.pformat

if short_self:
try:
self = x["self"]
except KeyError:
pass
else:

def custom(obj):
if obj is self:
return wadler_lindig.TextDoc(f"{type(obj).__name__}(...)")

pformat = ft.partial(pprint.pformat, indent=2, compact=True)
pformat = ft.partial(pformat, custom=custom)
try:
return pformat(x)
except Exception:
Expand Down
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "jaxtyping"
version = "0.2.36"
description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays."
readme = "README.md"
requires-python =">=3.9"
requires-python =">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "Patrick Kidger", email = "[email protected]"},
Expand All @@ -23,9 +23,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/google/jaxtyping" }
dependencies = [
"typing_extensions; python_version < '3.10'"
]
dependencies = ["wadler_lindig>=0.1.0"]
entry-points = {pytest11 = {jaxtyping = "jaxtyping._pytest_plugin"}}

[build-system]
Expand Down

0 comments on commit c39abe7

Please sign in to comment.