Skip to content

Commit

Permalink
Add fp8 types exposed in jax.numpy.
Browse files Browse the repository at this point in the history
  • Loading branch information
liudangyi authored and patrick-kidger committed Nov 19, 2024
1 parent 4307e19 commit 51345dd
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 1 deletion.
10 changes: 10 additions & 0 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@
Complex64 as Complex64,
Complex128 as Complex128,
Float as Float,
Float8e4m3b11fnuz as Float8e4m3b11fnuz,
Float8e4m3fn as Float8e4m3fn,
Float8e4m3fnuz as Float8e4m3fnuz,
Float8e5m2 as Float8e5m2,
Float8e5m2fnuz as Float8e5m2fnuz,
Float16 as Float16,
Float32 as Float32,
Float64 as Float64,
Expand Down Expand Up @@ -110,6 +115,11 @@
Complex64 as Complex64,
Complex128 as Complex128,
Float as Float,
Float8e4m3b11fnuz as Float8e4m3b11fnuz,
Float8e4m3fn as Float8e4m3fn,
Float8e4m3fnuz as Float8e4m3fnuz,
Float8e5m2 as Float8e5m2,
Float8e5m2fnuz as Float8e5m2fnuz,
Float16 as Float16,
Float32 as Float32,
Float64 as Float64,
Expand Down
20 changes: 19 additions & 1 deletion jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,12 @@ def __init_subclass__(cls, **kwargs):
_int16 = "int16"
_int32 = "int32"
_int64 = "int64"
# fp8 types exposed in Jax, see https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py#L92-L97
_float8_e4m3b11fnuz = "float8_e4m3b11fnuz"
_float8_e4m3fn = "float8_e4m3fn"
_float8_e4m3fnuz = "float8_e4m3fnuz"
_float8_e5m2 = "float8_e5m2"
_float8_e5m2fnuz = "float8_e5m2fnuz"
_bfloat16 = "bfloat16"
_float16 = "float16"
_float32 = "float32"
Expand Down Expand Up @@ -761,6 +767,11 @@ class _Cls(AbstractDtype):
Int16 = _make_dtype(_int16, "Int16")
Int32 = _make_dtype(_int32, "Int32")
Int64 = _make_dtype(_int64, "Int64")
Float8e4m3b11fnuz = _make_dtype(_float8_e4m3b11fnuz, "Float8e4m3b11fnuz")
Float8e4m3fn = _make_dtype(_float8_e4m3fn, "Float8e4m3fn")
Float8e4m3fnuz = _make_dtype(_float8_e4m3fnuz, "Float8e4m3fnuz")
Float8e5m2 = _make_dtype(_float8_e5m2, "Float8e5m2")
Float8e5m2fnuz = _make_dtype(_float8_e5m2fnuz, "Float8e5m2fnuz")
BFloat16 = _make_dtype(_bfloat16, "BFloat16")
Float16 = _make_dtype(_float16, "Float16")
Float32 = _make_dtype(_float32, "Float32")
Expand All @@ -771,7 +782,14 @@ class _Cls(AbstractDtype):
bools = [_bool, _bool_]
uints = [_uint4, _uint8, _uint16, _uint32, _uint64]
ints = [_int4, _int8, _int16, _int32, _int64]
floats = [_bfloat16, _float16, _float32, _float64]
float8 = [
_float8_e4m3b11fnuz,
_float8_e4m3fn,
_float8_e4m3fnuz,
_float8_e5m2,
_float8_e5m2fnuz,
]
floats = float8 + [_bfloat16, _float16, _float32, _float64]
complexes = [_complex64, _complex128]

# We match NumPy's type hierarachy in what types to provide. See the diagram at
Expand Down
5 changes: 5 additions & 0 deletions jaxtyping/_indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
Annotated as Complex64, # noqa: F401
Annotated as Complex128, # noqa: F401
Annotated as Float, # noqa: F401
Annotated as Float8e4m3b11fnuz, # noqa: F401
Annotated as Float8e4m3fn, # noqa: F401
Annotated as Float8e4m3fnuz, # noqa: F401
Annotated as Float8e5m2, # noqa: F401
Annotated as Float8e5m2fnuz, # noqa: F401
Annotated as Float16, # noqa: F401
Annotated as Float32, # noqa: F401
Annotated as Float64, # noqa: F401
Expand Down
5 changes: 5 additions & 0 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ def test_dtypes():
Complex64,
Complex128,
Float,
Float8e4m3b11fnuz,
Float8e4m3fn,
Float8e4m3fnuz,
Float8e5m2,
Float8e5m2fnuz,
Float16,
Float32,
Float64,
Expand Down

0 comments on commit 51345dd

Please sign in to comment.