Skip to content

Commit

Permalink
Now statically working with dataclasses correctly.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jan 2, 2025
1 parent b2f16e8 commit 2d311b5
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
NoReturn,
overload,
TypeVar,
Union,
)


Expand All @@ -53,6 +54,10 @@

_Params = ParamSpec("_Params")
_Return = TypeVar("_Return")
_T = TypeVar("_T")
# Not `TypeVar(..., type, Callable)` as else the output type of our first overload is
# just `type`, and not the particular class that is decorated.
_TypeOrCallable = TypeVar("_TypeOrCallable", bound=Union[type, Callable])


class _Sentinel:
Expand All @@ -77,7 +82,11 @@ def _apply_typechecker(typechecker, fn):
def jaxtyped(
*,
typechecker=_sentinel,
) -> Callable[[Callable[_Params, _Return]], Callable[_Params, _Return]]: ...
) -> Callable[[_TypeOrCallable], _TypeOrCallable]: ...


@overload
def jaxtyped(fn: type[_T], *, typechecker=_sentinel) -> type[_T]: ...


@overload
Expand Down

0 comments on commit 2d311b5

Please sign in to comment.