diff --git a/docs/jep/11859-type-annotations.md b/docs/jep/11859-type-annotations.md index 1f08928eb014..74079a170f15 100644 --- a/docs/jep/11859-type-annotations.md +++ b/docs/jep/11859-type-annotations.md @@ -1,8 +1,10 @@ -# Type Annotation Roadmap for JAXS +# Type Annotation Roadmap for JAX + _jakevdp@_ _Aug 2022_ ## Background + Python 3.0 introduced optional function annotations ([PEP 3107](https://peps.python.org/pep-3107/)), which were later codified for use in static type checking around the release of Python 3.5 ([PEP 484](https://peps.python.org/pep-0484/)). To some degree, type annotations and static type checking have become an integral part of many Python development workflows, and to this end we have added annotations in a number of places throughout the JAX API. The current state of type annotations in JAX is a bit patchwork, and efforts to add more have been hampered by more fundamental design questions. @@ -12,11 +14,14 @@ Why do we need such a roadmap? Better/more comprehensive type annotations are a In addition, we frequently receive pull requests from external users seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python. This document details JAX's goals and recommendations for type annotations within the package. -## Why Type Annotations? -There are a number of reasons that a Python project might wish to annotate their code-base; I'll summarize them in this document as Level 1, Level 2, and Level 3. +## Why type annotations? + +There are a number of reasons that a Python project might wish to annotate their code-base; we'll summarize them in this document as Level 1, Level 2, and Level 3. + +### Level 1: annotations as documentation + +When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](http://google3/third_party/py/jax/_src/lax/slicing.py;l=48-54;rcl=446366018)]: -### Level 1: Annotations as Documentation -When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to Any. An example can be found in `lax/slicing.py` [[source](http://google3/third_party/py/jax/_src/lax/slicing.py;l=48-54;rcl=446366018)]: ```python Array = Any Shape = core.Shape @@ -26,22 +31,25 @@ def slice(operand: Array, start_indices: Sequence[int], strides: Optional[Sequence[int]] = None) -> Array: ... ``` + For the purposes of static type checking, this use of `Array = Any` for array annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all) but it does serve as a form of useful in-code documentation for the developer. -For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type Any), so the documentation benefit does not go beyond the source code. +For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code. A benefit of this level of type annotation is that it is never wrong to annotate a value with `Any`, so it will provide a concrete benefit to developers and users in the form of documentation, without added complexity of satisfying the stricter needs of any particular static type checker. -### Level 2: Annotations for Intelligent Autocomplete +### Level 2: annotations for intelligent autocomplete Many modern IDEs take advantage of type annotations to improve [intelligent code completion](https://en.wikipedia.org/wiki/Intelligent_code_completion) in IDEs. One example of this is the [PyLance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance) extension for VSCode, which uses Microsofts' [pyright](https://github.com/microsoft/pyright) static type checker as a source of information for VSCode's [IntelliSense](https://code.visualstudio.com/docs/editor/intellisense) completions. This use of type checking requires going further than the simple aliases used above; for example, knowing that the `slice` function returns an alias of `Any` named `Array` does not add any useful information to the code completion engine. However, were we to annotate the function with a `DeviceArray` return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development. JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](http://google3/third_party/py/jax/_src/random.py;l=359;rcl=445522278)]: + ```python def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray: ... ``` + In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](http://google3/third_party/py/jax/_src/numpy/ndarray.py;l=40;rcl=431999528)), and so PyLance in VSCode can offer the full set of autocompletions on results from this function. Here is the result in VSCode: ![VSCode Intellisense Screenshot](../_static/vscode-completion.png) @@ -49,7 +57,7 @@ In this case `jnp.ndarray` is an abstract base class that forward-declares the a Listed are all methods and attributes declared by the abstract class We'll discuss further below why it was necessary to create this abstract class rather than annotating with `DeviceArray` directly. -### Level 3: Annotations for Static Type-Checking +### Level 3: annotations for static type-checking Static type-checking often is the first thing people think of when considering the purpose of type annotations in Python code. While Python does not do any runtime checking of types, several mature static type checking tools exist that can do this as part of a CI test suite. @@ -68,46 +76,53 @@ These typically represent cases where typing problems have arisen; they may be i On occasion, they are due to real & subtle bugs in the behavior of pytype and/or mypy. In rare cases, they may be due to the fact that JAX uses Python patterns that are difficult or even impossible to express in terms of Python's static type annotation syntax. -## Type Annotation Challenges for JAX -As we've seen above, JAX currently has type annotations that are a mish-mash of styles, and aimed at all three areas of type annotation application above. +## Type annotation challenges for JAX + +As demonstrated above, JAX currently has type annotations that are a mish-mash of styles, and aimed at all three areas of type annotation application above. Partly, this comes from the fact that JAX's source code poses a number of unique challenges for Python's type annotation system. We'll outline them here. ### Challenge 1: pytype, mypy and developer friction + One challenge JAX faces is that package development must satisfy the constraints of two different static type checking systems, `pytype` (used by internal CI and internal Google projects) and `mypy` (used by external CI and external dependencies). -Although the two type checkers have broad overlap in their behavior, each presents its own unique corner cases, as evidenced by the numerous `#type: ignore` and `#pytype: disable` statements througout the JAX codebase. +Although the two type checkers have broad overlap in their behavior, each presents its own unique corner cases, as evidenced by the numerous `#type: ignore` and `#pytype: disable` statements throughout the JAX codebase. This creates friction in development: internal contributors may iterate until tests pass, only to find that on export their pytype-approved code falls afoul of mypy. -For external contributors, it's often the opposite: a recent example is {jax-issue}`#9596` which had to be rolled-back after it failed internal Google continous checks. +For external contributors, it's often the opposite: a recent example is {jax-issue}`#9596` which had to be rolled-back after it failed internal Google continuous checks. Each time we move a type annotation from Level 1 (`Any` everywhere) to Level 2 or 3 (stricter annotations), it introduces more of these potential sharp edges. -### Challenge 2: Array duck-typing -One particular challenge for annotating JAX code is its heavy use of duck-typing. An input to a function marked `Array` in general could be one of many different types: a jax `DeviceArray`, a numpy `np.ndarray`, a numpy scalar, a Python scalar, a Python sequence, an object with an `__array__` attribute, an object with a `__jax_array__` attribute, or any flavor of `jax.Tracer`. +### Challenge 2: array duck-typing + +One particular challenge for annotating JAX code is its heavy use of duck-typing. An input to a function marked `Array` in general could be one of many different types: a JAX `DeviceArray`, a NumPy `np.ndarray`, a NumPy scalar, a Python scalar, a Python sequence, an object with an `__array__` attribute, an object with a `__jax_array__` attribute, or any flavor of `jax.Tracer`. For this reason, simple annotations like `def func(x: jnp.ndarray)` will not be sufficient, and will lead to false positives for many valid uses. This means that type annotations for JAX functions will not be short or trivial, but we would have to effectively develop a set of JAX-specific typing extensions similar to those in the [`numpy.typing` package](https://github.com/numpy/numpy/blob/main/numpy/_typing/_array_like.py). -### Challenge 3: Transformations and decorators +### Challenge 3: transformations and decorators + JAX's Python API relies heavily on function transformations ({func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, etc.), and this type of API poses a particular challenge for static type analysis. Flexible annotation for decorators has been a [long-standing issue](https://github.com/python/mypy/issues/1927) in the mypy package, which was only recently resolved by the introduction of `ParamSpec` in [PEP 612](https://peps.python.org/pep-0612/), available starting in Python 3.10. Because JAX follows [NEP 29](https://numpy.org/neps/nep-0029-deprecation_policy.html), it cannot rely on Python 3.10 features until sometime after mid-2024. In the meantime, Protocols can be used as a partial solution to this (JAX added this for jit and other methods in {jax-issue}`#9950`) and the future approach may be possible via the `typing_extensions` package (a prototype is in {jax-issue}`#9999`) though this currently reveals fundamental bugs in mypy (see [python/mypy#12593](https://github.com/python/mypy/issues/12593)). All that to say: it's not yet clear that the API of JAX's function transforms can be suitably annotated within the current constraints of Python type annotation tools. -### Challenge 4: Array annotation lack of granularity +### Challenge 4: array annotation lack of granularity + Another challenge here is common to all Python array APIs, and has been part of the JAX discussion for several years (see {jax-issue}`#943`). Type annotations have to do with the Python class or type of an object, whereas in array-based languages often the attributes of the class are more important. In the case of NumPy, JAX, and similar packages, often we would wish to annotate particular array shapes and data types. For example, the arguments to the `jnp.linspace` function must be scalar values, but in JAX scalars are represented by zero-dimensional arrays. So in order for annotations to not raise false positives, we must allow these arguments to be *arbitrary* arrays. -Another example is the second argument to jax.random.choice, which must have `dtype=int` when `shape=()`. +Another example is the second argument to `jax.random.choice`, which must have `dtype=int` when `shape=()`. Python has a plan to enable type annotations with this level of granularity via Variadic Type Generics (see [PEP 646](https://peps.python.org/pep-0646/), slated for Python 3.11) but again due to [NEP 29](https://numpy.org/neps/nep-0029-deprecation_policy.html) this construct will not be available to JAX until mid-2025 at the earliest. There are some third-party projects that may help in the meantime, in particular [google/jaxtyping](https://github.com/google/jaxtyping), but this uses non-standard annotations and may not be suitable for annotating JAX itself. All told, the array-type-granularity challenge is less of an issue than the other challenges, because the main effect is that array-like annotations will be less specific than they otherwise could be. -### Challenge 5: Imprecise APIs inherited from NumPy +### Challenge 5: imprecise APIs inherited from NumPy + A large part of JAX’s user-facing API is inherited from NumPy within the `jax.numpy` submodule. NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-eafp) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the `numpy.tile` function, which is defined like this: + ```python def tile(A, reps): try: @@ -117,6 +132,7 @@ def tile(A, reps): d = len(tup) ... ``` + Here the *intent* is that `reps` would contain either an `int` or a sequence of `int` values, but the *implementation* allows `tup` to be any iterable. When adding annotations to this kind of duck-typed code, we could take one of two routes: 1. We may choose to annotate the *intent* of the function's API, which here might be something like `reps: Union[int, Sequence[int]]`. @@ -127,7 +143,8 @@ Gradual typing of an existing duck-typed API means that the current annotation i Broadly speaking, annotating intent better serves Level 1 type checking, while annotating implementation better serves Level 3, while Level 2 is more of a mixed bag (both intent and implementation are important when it comes to annotations in IDEs). -## JAX Type Annotation Roadmap +## JAX type annotation roadmap + With this framing (Level 1/2/3) and JAX-specific challenges in mind, we can begin to develop our roadmap for implementing consistent type annotations across the JAX project. For JAX type annotation, we have the following goals: @@ -135,10 +152,10 @@ For JAX type annotation, we have the following goals: 1. We would like to support full, *Level 1, 2, and 3* type annotation as far as possible. In particular, this means that we should have restrictive type annotations on both inputs and outputs to public API functions. 2. In order to not add undue development friction (due to the internal/external CI differences), we would like to be conservative in the type annotation constructs we use: in particular, when it comes to recently-introduced mechanisms such as `ParamSpec` (PEP [PEP 612](https://peps.python.org/pep-0612/),), we would like to wait until support in mypy and other tools stabilizes before relying on them. - One impact of this is that for the time being, when functions are decorated by jax transformations like `jit`, `vmap`, `grad`, etc. JAX will **strip all annotations**. This is because `ParamSpec` is still only partially supported; the PEP is slated for Python 3.10 (though it can be used before that via [typing-extensions](https://github.com/python/typing_extensions) and at the time of this writing mypy has a laundry-list of incompatibilities with the `ParamSpec`-based annotations (see [`ParamSpec` mypy bug tracker](https://github.com/python/mypy/issues?q=is%3Aissue+is%3Aopen++label%3Atopic-paramspec+)). + One impact of this is that for the time being, when functions are decorated by JAX transformations like `jit`, `vmap`, `grad`, etc. JAX will **strip all annotations**. This is because `ParamSpec` is still only partially supported; the PEP is slated for Python 3.10 (though it can be used before that via [typing-extensions](https://github.com/python/typing_extensions) and at the time of this writing mypy has a laundry-list of incompatibilities with the `ParamSpec`-based annotations (see [`ParamSpec` mypy bug tracker](https://github.com/python/mypy/issues?q=is%3Aissue+is%3Aopen++label%3Atopic-paramspec+)). We will revisit this question in the future once support for such features stabilizes. -3. JAX type annotations shoudl in general indicate the **intent** of APIs, rather than the implementation, so that the annotations become useful to communicate the contract of the API. This means that at times inputs that are valid at runtime may not be recognized as valid by the static type checker (a simple example is an arbitrary iterator passed in place of a shape in some function implementations). +3. JAX type annotations should in general indicate the **intent** of APIs, rather than the implementation, so that the annotations become useful to communicate the contract of the API. This means that at times inputs that are valid at runtime may not be recognized as valid by the static type checker (a simple example is an arbitrary iterator passed in place of a shape in some function implementations). 4. Inputs to JAX functions and methods should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class `np.dtype`, but rather any dtype-convertible object. This might include strings, built-in scalar types, or dtype-adjacent classes such as `np.float64` and `jnp.float64`. In order to make this as uniform as possible across the package, we will add a {mod}`jax.typing` module with common type specifications, starting with broad categories such as: @@ -147,7 +164,7 @@ For JAX type annotation, we have the following goals: - `ShapeLike` - etc. - Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DtypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in most places, so the type definition will be simpler than the numpy analog. + Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DtypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in most places, so the type definition will be simpler than the NumPy analog. 5. Conversely, outputs of functions and methods should be typed as strictly as possible: for example, for a JAX function that returns an array, the output should be annotated with `jnp.ndarray` rather than `ArrayLike`. Functions returning a dtype should always be annotated `np.dtype`, and functions returning a shape should always be `Tuple[int]` or a strictly-typed NamedShape equivalent. For this purpose, we will implement in {mod}`jax.typing` several strictly-typed analogs of the permissive types mentioned above, namely: @@ -157,4 +174,4 @@ For JAX type annotation, we have the following goals: - `NamedShape` - etc. -6. Aside from common typing protocols gathered in `jax.typing`, we should err on the side of simplicity, and avoid constructing overly-complex protocols for arguments passed to API functions, and instead use simple unions such as `Union[simple_type, Any]` in the case that the full type specification of the API cannot be succinctly specified. This is a comprimise that achieves the goals of Level 1 and 2 annotations, while punting on Level 3 in favor of simplicity. +6. Aside from common typing protocols gathered in `jax.typing`, we should err on the side of simplicity, and avoid constructing overly-complex protocols for arguments passed to API functions, and instead use simple unions such as `Union[simple_type, Any]` in the case that the full type specification of the API cannot be succinctly specified. This is a compromise that achieves the goals of Level 1 and 2 annotations, while punting on Level 3 in favor of simplicity.