Backend tensorflow.compat.v1 supports forward-mode automatic differentiation via double backwards trick #1614

Merged · 15 commits · Jan 3, 2024
30 changes: 20 additions & 10 deletions deepxde/gradients/gradients_forward.py
@@ -20,19 +20,22 @@ def __call__(self, i=None, j=None):

# Compute J[:, j]
if j not in self.J:
if backend_name in [
"tensorflow.compat.v1",
"paddle",
]:
# TODO: Other backends
raise NotImplementedError(
"Backend f{backend_name} doesn't support forward-mode autodiff."
if backend_name == "tensorflow.compat.v1":
# We use the double backwards trick to compute the JVP of a function in
# backend tensorflow.compat.v1, because tf.autodiff.ForwardAccumulator is
# not supported there. Note that this is not a true forward-mode JVP; it is
# assembled from two reverse-mode (tf.gradients) passes.
tangent = tf.one_hot([j], depth=self.xs.shape[1]) * tf.ones_like(
self.xs
)
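# g below equals (dys/dxs)^T u, a reverse-mode VJP seeded with a dummy
# cotangent u; since g is linear in u, differentiating g w.r.t. u with
# grad_ys=tangent recovers the JVP (dys/dxs) @ tangent.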
u = tf.ones_like(self.ys)
g = tf.gradients(self.ys, self.xs, grad_ys=u)
self.J[j] = tf.gradients(g, u, grad_ys=tangent)[0]
elif backend_name == "tensorflow":
# We use tensorflow.autodiff.ForwardAccumulator to compute the jvp of
# a function.
# TODO: create the tangent in a smarter way
tangent = tf.one_hot(self.xs.shape[0] * [j], depth=self.xs.shape[1])
tangent = tf.one_hot([j], depth=self.xs.shape[1]) * tf.ones_like(
self.xs
)

def grad_fn(x):
with tf.autodiff.ForwardAccumulator(
@@ -72,13 +75,20 @@ def grad_fn(x):
tangent = jax.numpy.zeros(self.dim_x).at[j].set(1)
grad_fn = lambda x: jax.jvp(self.ys[1], (x,), (tangent,))[1]
self.J[j] = (jax.vmap(grad_fn)(self.xs), grad_fn)
elif backend_name == "paddle":
# TODO: Other backends
raise NotImplementedError(
"Backend f{backend_name} doesn't support forward-mode autodiff."
)

if i is None or self.dim_y == 1:
return self.J[j]

# Compute J[i, j]
if (i, j) not in self.J:
if backend_name in ["tensorflow", "pytorch", "jax"]:
if backend_name == "tensorflow.compat.v1":
self.J[i, j] = self.J[j][:, i : i + 1]
elif backend_name in ["tensorflow", "pytorch", "jax"]:
# In backend tensorflow/pytorch/jax, a tuple of a tensor/tensor/array
# and a callable is returned, so that it is consistent with the argument,
# which is also a tuple. This is useful for further computation, e.g.,
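For readers unfamiliar with the trick, the following standalone sketch (not part of the PR) applies the same two nested `tf.gradients` calls to a toy function whose Jacobian is known analytically; the function `y`, the input `x`, and the column index `j` are illustrative assumptions, not DeepXDE API.

```python
# Minimal sketch of the double backwards trick, assuming the toy function
# y(x) = [x0 * x1, x0 + x1] so the Jacobian column can be checked by hand.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.placeholder(tf.float64, shape=[None, 2])
y = tf.stack([x[:, 0] * x[:, 1], x[:, 0] + x[:, 1]], axis=1)

j = 0  # Jacobian column to extract, i.e. dy/dx_j
tangent = tf.one_hot([j], depth=2, dtype=tf.float64) * tf.ones_like(x)

# Double backwards trick:
#   g = J^T u                  (reverse-mode VJP with a dummy cotangent u)
#   d(tangent . g)/du = J @ tangent   (the forward-mode JVP)
u = tf.ones_like(y)
g = tf.gradients(y, x, grad_ys=u)
jvp = tf.gradients(g, u, grad_ys=tangent)[0]

with tf.Session() as sess:
    out = sess.run(jvp, feed_dict={x: np.array([[2.0, 3.0], [5.0, 7.0]])})
    # Column 0 of the Jacobian is [d(x0*x1)/dx0, d(x0+x1)/dx0] = [x1, 1],
    # so this prints [[3. 1.] [7. 1.]].
    print(out)
```

Backends with native forward mode (tensorflow eager via tf.autodiff.ForwardAccumulator, jax via jax.jvp) compute the same Jacobian column directly, as in the other branches of the diff above.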