Current as of 9/15/2020
TFP supports alternative numerical backends to TensorFlow, including both JAX and NumPy. This document explains some of the details of supporting disparate backends and how we have handled them, with the goal of helping contributors read and write code and understand how these alternative-substrate packages are assembled.
In `tensorflow_probability/python/internal/backend` we find the implementations
of both the NumPy and JAX backends. These imitate the portion of the TensorFlow
API used by TFP, but are implemented in terms of the corresponding substrate's
primitives.

Since JAX provides `jax.numpy`, we are in most cases able to write a single
NumPy implementation under `tensorflow_probability/python/internal/backend`,
then use a rewrite script (found at
`tensorflow_probability/python/internal/backend/jax/rewrite.py`) to generate a
JAX variant at `bazel build` time. See the genrules in
`tensorflow_probability/python/internal/backend/jax/BUILD` for more details of
how this occurs. In cases where JAX provides a different API (e.g. `random`) or
a more performant API (e.g. batched matrix decompositions, `vmap`, etc.), we
will special-case using an `if JAX_MODE:` block.
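As an illustration, here is a minimal sketch (not the actual backend source) of
how such a shim might look, assuming a module-level `JAX_MODE` constant of the
kind used throughout `internal/backend`:

```python
# Hypothetical backend shim; the real implementations live under
# tensorflow_probability/python/internal/backend/numpy/.
JAX_MODE = False  # Assumed to be set to True in the generated JAX variant.

if JAX_MODE:
  import jax.numpy as backend_np
else:
  import numpy as backend_np


def matmul(a, b, name=None):
  """Imitates tf.linalg.matmul on top of the active NumPy-like module."""
  del name  # Unused; kept for TF API compatibility.
  return backend_np.matmul(a, b)
```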
-   **Shapes**

    In TensorFlow, the `shape` attribute of a `Tensor` is a `tf.TensorShape`
    object, whereas in JAX and NumPy it is a simple `tuple` of `int`s. To
    handle cases for both systems nicely, we use
    `from tensorflow_probability.python.internal import tensorshape_util` and
    move special-casing into this library.
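    A brief, hypothetical sketch of the idiom (the helper functions here are
    made up for illustration; only the `tensorshape_util` calls come from TFP):

    ```python
    from tensorflow_probability.python.internal import tensorshape_util


    def static_rank_or_none(x):
      # Works whether x.shape is a tf.TensorShape or a plain tuple.
      return tensorshape_util.rank(x.shape)


    def static_shape_as_list(x):
      # Raises if the shape is not fully known; returns a list of ints otherwise.
      if not tensorshape_util.is_fully_defined(x.shape):
        raise ValueError('Expected a statically known shape.')
      return tensorshape_util.as_list(x.shape)
    ```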
-   **DTypes**

    TensorFlow, and to some extent JAX, use the dtype of inputs to infer the
    dtype of outputs, and generally try to preserve the `Tensor` dtype in a
    binary op between a `Tensor` and a non-`Tensor`. NumPy, on the other hand,
    aggressively pushes dtypes toward 64-bit when left unspecified. While
    debugging the JAX substrate, we have in some cases seen dtypes change from
    float32 to float64, or vice versa, across the iterations of a while loop;
    finding where this change happens can be tricky. Where possible, we aim to
    fix these issues in the `internal/backend` package, as opposed to the
    implementation files.
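    A quick illustration of the default-dtype gap (integer defaults are
    platform-dependent; treat this as a sketch rather than a spec):

    ```python
    import numpy as np
    import tensorflow.compat.v2 as tf

    np.array(1.).dtype      # float64 -- NumPy defaults to 64-bit.
    tf.constant(1.).dtype   # tf.float32 -- TF (and JAX by default) stays 32-bit.
    np.arange(3).dtype      # int64 on most platforms.
    tf.range(3).dtype       # tf.int32
    ```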
-   **Shapes, again (JAX "omnistaging" / `prefer_static`)**

    Every JAX primitive observed within a JIT or JAX control flow context
    becomes an abstract `Tracer`. This is similar to `@tf.function`. The main
    challenge this introduces for TFP is that TF allows dynamic shapes as
    `Tensor`s, whereas JAX (being built atop XLA) needs shapes to be statically
    available (i.e. a `tuple` or `numpy.ndarray`, not a JAX `ndarray`).

    If you observe issues with shapes derived from `Tracer`s in JAX, often a
    simple fix is
    `from tensorflow_probability.python.internal import prefer_static as ps`
    followed by replacing `tf.shape` with `ps.shape`, and similarly for other
    ops such as `tf.rank`, `tf.size`, `tf.concat` (when dealing with shapes),
    (the args to) `tf.range`, etc. It's also useful to be aware of
    `ps.convert_to_shape_tensor`, which behaves like `tf.convert_to_tensor` for
    TF but leaves things as `np.ndarray` for JAX. Similarly, in constructors,
    use the `as_shape_tensor=True` arg to
    `tensor_util.convert_nonref_to_tensor` for shape-related values.
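    A before/after sketch of this kind of fix (the function and argument names
    here are made up for illustration):

    ```python
    import tensorflow.compat.v2 as tf
    from tensorflow_probability.python.internal import prefer_static as ps


    def zeros_with_batch_shape_of(x, dims):
      # Before (can break under jax.jit): batch_shape = tf.shape(x)[:-1]
      # After: ps.shape returns a static NumPy shape whenever one is available,
      # so the shape munging below stays off the JAX trace.
      batch_shape = ps.shape(x)[:-1]
      return tf.zeros(ps.concat([batch_shape, [dims]], axis=0), dtype=x.dtype)
    ```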
-   **`tf.GradientTape`**

    TF uses a tape to record ops for later gradient evaluation, whereas JAX
    rewrites a function while tracing its execution. Since the function
    transform is more general, we aim to replace usage of `GradientTape` (in
    tests, TFP implementation code, etc.) with `tfp.math.value_and_gradient` or
    similar. Then, we can special-case `JAX_MODE` inside the body of
    `value_and_gradient`.
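    For example, a sketch of the substrate-agnostic style:

    ```python
    import tensorflow.compat.v2 as tf
    import tensorflow_probability as tfp

    x = tf.constant([1., 2., 3.])
    # Instead of tf.GradientTape, which has no JAX analogue, use
    # tfp.math.value_and_gradient, which can special-case JAX_MODE internally.
    y, dy_dx = tfp.math.value_and_gradient(lambda t: tf.reduce_sum(t ** 2), x)
    ```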
-   **`tf.Variable`, `tf_keras.optimizers.Optimizer`**

    TF provides a `Variable` abstraction so that graph functions may modify
    state, including using the Keras `Optimizer` subclasses like `Adam`. JAX,
    in contrast, operates only on pure functions. In general, TFP is fairly
    functional (e.g. `tfp.optimizer.lbfgs_minimize`), but in some cases (e.g.
    `tfp.vi.fit_surrogate_posterior`,
    `tfp.optimizer.StochasticGradientLangevinDynamics`) we have felt the
    mismatch too strong to try to port the code to JAX. Some approaches to
    hoisting state out of a stateful function can be seen in the TFP spinoff
    project `oryx`.
-   **Custom derivatives**

    JAX supports both forward- and reverse-mode autodifferentiation, and where
    possible TFP aims to support both in JAX. To do so, in places where we
    define a custom derivative, we use an internal wrapper which provides a
    function decorator that supports both TF's and JAX's interfaces for custom
    derivatives, namely:

    ```python
    from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient

    def _f(..): pass
    def _f_fwd(..): return _f(..), bwd_auxiliary_data
    def _f_bwd(bwd_auxiliary_data, dy): pass
    def _f_jvp(primals, tangents): return _f(*primals), df(primals, tangents)

    @tfp_custom_gradient.custom_gradient(
        vjp_fwd=_f_fwd, vjp_bwd=_f_bwd, jvp_fn=_f_jvp)
    def f(..): return _f(..)
    ```

    For more information, the JAX custom derivatives doc can be useful.
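    A hypothetical concrete instance of this pattern (not taken from the TFP
    sources), assuming the decorator interface shown above:

    ```python
    from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient


    def _cube(x):
      return x ** 3


    def _cube_fwd(x):
      # VJP forward pass: return the primal output plus residuals needed by
      # the backward pass.
      return _cube(x), x


    def _cube_bwd(x, dy):
      # Reverse mode: dy * d/dx(x**3) = dy * 3 x**2.
      return (3. * x ** 2 * dy,)


    def _cube_jvp(primals, tangents):
      # Forward mode: tangent_out = 3 x**2 * tangent_in.
      x, = primals
      dx, = tangents
      return _cube(x), 3. * x ** 2 * dx


    @tfp_custom_gradient.custom_gradient(
        vjp_fwd=_cube_fwd, vjp_bwd=_cube_bwd, jvp_fn=_cube_jvp)
    def cube(x):
      return _cube(x)
    ```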
-   **Randomness**

    In TF, we support both "stateful" (i.e. some latent memory tracks the state
    of the sampler) and "stateless" sampling. JAX natively supports only
    stateless sampling, for functional purity reasons. For internal use, we
    have `from tensorflow_probability.python.internal import samplers`, a
    library that provides methods to:

    -   convert stateful seeds to stateless and add salts (`sanitize_seed`)
    -   split stateless seeds into multiple descendant seeds (`split_seed`)
    -   proxy through to a number of stateless samplers (`normal`, `uniform`,
        ...)

    When the rewrite script is dealing with a `..._test.py` file, it will
    rewrite calls to `tf.random.{uniform,normal,...}` to
    `tf.random.stateless_{uniform,normal,...}` to ensure compatibility with the
    JAX backend, which only implements the stateless samplers.
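    A sketch of this seed-handling idiom (the function below is illustrative,
    not part of TFP):

    ```python
    from tensorflow_probability.python.internal import samplers


    def sample_pair(seed):
      # Accept stateful-style or stateless seeds and apply a salt.
      seed = samplers.sanitize_seed(seed, salt='sample_pair')
      # Derive independent seeds for each downstream sampling call.
      normal_seed, uniform_seed = samplers.split_seed(seed, n=2)
      x = samplers.normal(shape=[3], seed=normal_seed)
      y = samplers.uniform(shape=[3], seed=uniform_seed)
      return x, y
    ```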
In a couple of cases, we commit script-munged source from TensorFlow into the
repository. These files can be found under
`tensorflow_probability/python/internal/backend/numpy/gen`. They currently
include:

-   an implementation of `tf.TensorShape`
-   several parts of `tf.linalg`, especially the `tf.linalg.LinearOperator`
    classes

The actual rewriting is accomplished by scripts found under
`tensorflow_probability/python/internal/backend/meta`, namely
`gen_linear_operators.py` and `gen_tensor_shape.py`.
The test
`tensorflow_probability/python/internal/backend/numpy/rewrite_equivalence_test.py`
verifies that the files in TensorFlow, when rewritten, match the files in the
`gen/` directory. The test uses `BUILD` dependencies on `genrule`s that apply
the rewrite scripts, and compares those genrule outputs to the source of the
files under the `gen/` directory.
Similar to the sources in `internal/backend/numpy`, the sources in
`internal/backend/numpy/gen` are rewritten by `jax/rewrite.py`. Note that the
files under `gen/` do not have the `numpy` import rewritten. This is because we
only want to rewrite TensorFlow usage of TensorFlow-ported code; typically when
TF code is using `numpy`, it is munging shapes, and JAX does not like shapes to
be munged using `jax.numpy` (they must use plain `numpy`).
With `internal/backend/{numpy,jax}` now ready to provide a `tf2jax` or
`tf2numpy` backend, we can proceed to the core packages of TFP.
The script `tensorflow_probability/substrates/meta/rewrite.py` runs on TFP
sources to auto-generate JAX and NumPy Python source corresponding to the given
TF source.

The most important job of the rewrite script is to rewrite
`import tensorflow.compat.v2 as tf` to
`from tensorflow_probability.python.internal.backend.jax import v2 as tf`.
Second to that, the script rewrites dependencies on TFP subpackages to
dependencies on the corresponding substrate-specific TFP subpackages. For
example, the line `from tensorflow_probability.python import math as tfp_math`
becomes `from tensorflow_probability.substrates.jax import math as tfp_math`.
Beyond that, there are a number of peripheral replacements to work around other
wrinkles we've accumulated over time.
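To make the rewrites concrete, here is an illustrative before/after for the JAX
substrate, using the two lines described above:

```python
# Original TF source:
import tensorflow.compat.v2 as tf
from tensorflow_probability.python import math as tfp_math

# After substrates/meta/rewrite.py (JAX variant):
from tensorflow_probability.python.internal.backend.jax import v2 as tf
from tensorflow_probability.substrates.jax import math as tfp_math
```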
In rare cases we will put an explicit `if JAX_MODE:` or `if NUMPY_MODE:` block
into the implementation code of a TFP submodule. This should be very uncommon.
Whenever possible, the intent is for such special-casing to live under
`python/internal`. For example, today we see in `bijectors/softplus.py`:
```python
# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
if JAX_MODE:
  _stable_grad_softplus = tf.nn.softplus
else:
  @tf.custom_gradient
  def _stable_grad_softplus(x):  # ...
```
Note that this rewrite currently adds exactly a 10-line header, so line numbers in stack traces will be offset by +10 relative to the raw code.
`tensorflow_probability/python/build_defs.bzl` defines a pair of bazel build
rules: `multi_substrate_py_library` and `multi_substrate_py_test`. These rules
automatically invoke `tensorflow_probability/substrates/meta/rewrite.py` to
emit JAX/NumPy source variants. The file `bijectors/softplus.py` gets rewritten
into `bijectors/_generated_jax_softplus.py` (you can view the output under the
corresponding `bazel-genfiles` directory). These build rules are also
responsible for rewriting TFP-internal deps to the `some_dep.jax` or
`some_dep.numpy` substrate-specific replacements.

The `multi_substrate_py_library` rule emits three targets: a TF `py_library`
with the name given by the `name` argument, a JAX `py_library` named
`name + '.jax'`, and a NumPy `py_library` named `name + '.numpy'`.
The `multi_substrate_py_test` rule emits three targets, named `name + '.tf'`,
`name + '.jax'`, and `name + '.numpy'`. Substrates specified by the
`disabled_substrates` arg will not have BUILD rules emitted at all; `jax_tags`
and `numpy_tags` can be used to specify tags that drop CI coverage while
keeping the target buildable and testable. The distinction is useful so that we
can track cases where we think a test should be fixable but we haven't fixed it
yet, as opposed to cases like HMC where we know the test will never pass for
NumPy, so we prefer not to even have the test target. All emitted test targets
are aggregated into a `test_suite` whose name corresponds to the original
`name` arg.
In cases where we know we will never be able to support a given feature, the
`substrates_omit_deps`, `jax_omit_deps`, and `numpy_omit_deps` args to
`multi_substrate_py_library` can be used to exclude things. Examples include
non-pure code or code using `tf.Variable` (JAX wants pure functions), or HMC
(no gradients in NumPy!). When rewriting an `__init__.py` file, the rewrite
script is set up to comment out imports and `__all__` lines corresponding to
the omitted deps.
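Putting these rules together, a BUILD file might look roughly like the sketch
below (the target names, deps, and the omitted/disabled choices are
hypothetical; only the rule and argument names come from `build_defs.bzl`):

```python
# Hypothetical BUILD excerpt.
load(
    "//tensorflow_probability/python:build_defs.bzl",
    "multi_substrate_py_library",
    "multi_substrate_py_test",
)

multi_substrate_py_library(
    name = "my_module",
    srcs = ["my_module.py"],
    # Hypothetical dep that uses tf.Variable, so it is omitted from the JAX
    # and NumPy substrates.
    substrates_omit_deps = [":stateful_helper"],
    deps = [
        ":stateful_helper",
        "//tensorflow_probability/python/internal:prefer_static",
    ],
)

multi_substrate_py_test(
    name = "my_module_test",
    srcs = ["my_module_test.py"],
    # Suppose the test needs gradients, so no NumPy test target is emitted.
    disabled_substrates = ["numpy"],
    deps = [":my_module"],
)
```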
In order to test against the same directory hierarchy as we use for wheel
packaging, `multi_substrate_py_library` does some internal gymnastics with a
custom bazel rule which adds symlinks into `tensorflow_probability/substrates`
that point to implementation files generated under
`bazel-genfiles/tensorflow_probability/python` (details in
`_substrate_runfiles_symlinks_impl` of `build_defs.bzl`).
When it comes to building the wheel, we must first use `cp -L` to resolve the
symlinks added as part of the bazel build. Otherwise the wheel does not follow
them and fails to include `tfp.substrates`. This `cp -L` command sits in
`pip_pkg.sh` (currently adjacent to this doc).
A couple of integration tests sit in
`tensorflow_probability/substrates/meta/jax_integration_test.py` and
`tensorflow_probability/substrates/meta/numpy_integration_test.py`.
We run these under CI after building and installing a wheel to verify that the
`tfp.substrates` packages load correctly and do not require a `tensorflow`
install.