[functorch] Disable calling Tensor.requires_grad_() inside a functorch transform (pytorch/functorch#849)

Fixes pytorch/functorch#847

We do not allow users to call requires_grad_() inside a functorch
transform. Calling requires_grad_() effectively asks for another layer
of autograd, but that does not actually work: setting up a new layer of
autograd requires additional work (e.g. pushing autograd onto the
DynamicLayerStack) that a bare requires_grad_() call cannot do.

Instead, when a user calls requires_grad_() (and similarly
retain_grad()) inside a transform, we raise a descriptive error
message.
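
For illustration, a minimal sketch of the user-facing behavior (the
functions f and g below are illustrative, not part of this commit):

import torch
from functorch import grad

def f(x):
    x.requires_grad_()  # in-place requires_grad_() inside a transform
    return x.sin().sum()

x = torch.randn(3)
# grad(f)(x) now raises a RuntimeError explaining that requires_grad_()
# is unsupported inside a function being transformed by functorch.

# The call is also unnecessary: grad() already makes its inputs
# differentiable, so dropping requires_grad_() is the fix.
def g(x):
    return x.sin().sum()

grad(g)(x)  # works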

This has the intended consequence of causing
torch.autograd.functional.{jvp, vjp, jacobian} to error out when called
inside a functorch transform. Users should use the functorch
equivalents (e.g. functorch.jvp, functorch.vjp, functorch.jacrev,
functorch.jacfwd) instead.
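
As a sketch of what "the functorch equivalent" looks like (the shapes
and the per-sample-Jacobian use case are illustrative, not part of this
commit): nesting torch.autograd.functional.jacobian under vmap now
raises, while composing functorch's jacrev with vmap works.

import torch
from functorch import vmap, jacrev

x = torch.randn(5, 3)

# Raises after this change: torch.autograd.functional.* inside vmap.
# vmap(lambda xi: torch.autograd.functional.jacobian(torch.sin, xi))(x)

# functorch equivalent: compose jacrev with vmap instead.
per_sample_jac = vmap(jacrev(torch.sin))(x)  # shape (5, 3, 3)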

Test Plan:
- added tests
zou3519 committed Jul 20, 2022
1 parent be92a01 commit eb292d7
Showing 5 changed files with 118 additions and 3 deletions.
functorch/functorch/_src/eager_transforms.py (7 additions & 1 deletion)

@@ -32,6 +32,7 @@
    _func_increment_nesting,
    _assert_wrapped_functional,
    _propagate_functional_input_mutation,
+    set_inplace_requires_grad_allowed,
)

argnums_t = Union[int, Tuple[int, ...]]
@@ -40,7 +41,12 @@
def _create_differentiable(inps, level=None):
    def create_differentiable(x):
        if isinstance(x, torch.Tensor):
-            return x.requires_grad_()
+            try:
+                set_inplace_requires_grad_allowed(True)
+                return x.requires_grad_()
+            finally:
+                set_inplace_requires_grad_allowed(False)

        raise ValueError(f'Thing passed to transform API must be Tensor, '
                         f'got {type(x)}')
    return tree_map(create_differentiable, inps)
functorch/functorch/csrc/DynamicLayer.cpp (21 additions & 2 deletions)

@@ -102,13 +102,26 @@ class FuncTorchTLS : public FuncTorchTLSBase {
  }

  void checkSupportsInplaceRequiresGrad() const override {
-    // Does nothing
+    TORCH_CHECK(dynamicLayerStack.size() == 0 || allow_inplace_requires_grad_,
+        "You are attempting to call Tensor.requires_grad_() (or perhaps using ",
+        "torch.autograd.functional.* APIs) inside of a function being transformed ",
+        "by a functorch transform. ",
+        "This is unsupported, please attempt to use the functorch transforms ",
+        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() "
+        "outside of a function being transformed instead.");
  }
  void checkSupportsRetainGrad() const override {
-    // Does nothing
+    TORCH_CHECK(dynamicLayerStack.size() == 0,
+        "You are attempting to call Tensor.retain_grad() ",
+        "inside of a function being transformed ",
+        "by a functorch transform. ",
+        "This is unsupported, please attempt to use the functorch transforms ",
+        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call retain_grad() "
+        "outside of a function being transformed instead.");
  }

  std::vector<DynamicLayer> dynamicLayerStack;
+  bool allow_inplace_requires_grad_ = false;
};

static FuncTorchTLS* getRawFunctorchTLS() {
@@ -122,6 +135,12 @@ static FuncTorchTLS* getRawFunctorchTLS() {
  return result;
}

+void setInplaceRequiresGradAllowed(bool allowed) {
+  auto* functorch_tls = getRawFunctorchTLS();
+  functorch_tls->allow_inplace_requires_grad_ = allowed;
+}


static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
  return getRawFunctorchTLS()->dynamicLayerStack;
}
functorch/functorch/csrc/DynamicLayer.h (2 additions & 0 deletions)

@@ -85,6 +85,8 @@ Tensor unwrapIfDead(const Tensor& tensor);
std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);

void setInplaceRequiresGradAllowed(bool allowed);


}
} // namespace at
functorch/functorch/csrc/init.cpp (1 addition & 0 deletions)

@@ -379,6 +379,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("_set_vmap_fallback_warning_enabled", &at::functorch::setVmapFallbackWarningEnabled, "Set vmap fallback warnings");
m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
m.def("set_inplace_requires_grad_allowed", &at::functorch::setInplaceRequiresGradAllowed);
m.def("dlevel", &at::functorch::dlevel, "dlevel");
m.def("dump_tensor", &at::functorch::dump_tensor, "dump_tensor");
m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
functorch/test/test_eager_transforms.py (87 additions & 0 deletions)

@@ -2157,6 +2157,93 @@ def f(x):
        new_cotangent = torch.randn(())
        self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))

    def test_requires_grad_inside_transform(self, device):
        def f(x):
            x.requires_grad_()
            return x.sin().sum()

        x = torch.randn(3)

        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
            vmap(f)(x)
        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
            grad(f)(x)
        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
            vmap(grad(f))(x)

        x = torch.randn([])
        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
            grad(grad(f))(x)

    def test_retain_grad_inside_transform(self, device):
        def f(x):
            y = x.sin()
            y.retain_grad()
            return y.sum()

        x = torch.randn(3)

        with self.assertRaisesRegex(RuntimeError, "Tensor.retain_grad()"):
            grad(f)(x)

    def test_autograd_functional_jacrev_inside_transform(self, device):
        def f(x):
            y = torch.autograd.functional.jacobian(lambda x: x.sin().sum(), x)
            return y

        B = 5
        x = torch.randn(B, 3)
        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
            vmap(f)(x)

        x = torch.randn([])
        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
            grad(f)(x)

    def test_autograd_functional_vjp_inside_transform(self, device):
        def f(x):
            y = torch.autograd.functional.vjp(lambda x: x.sin().sum(), x)
            return y

        B = 5
        x = torch.randn(B, 3)
        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
            vmap(f)(x)

        x = torch.randn([])
        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
            grad(f)(x)

    def test_autograd_functional_jvp_inside_transform(self, device):
        def f(x):
            t = torch.ones_like(x)
            y = torch.autograd.functional.jvp(lambda x: x.sin().sum(), (x,), (t,))
            return y

        B = 5
        x = torch.randn(B, 3)
        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
            vmap(f)(x)

        x = torch.randn([])
        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
            grad(f)(x)

    def test_autograd_functional_jacfwd_inside_transform(self, device):
        def f(x):
            y = torch.autograd.functional.jacobian(
                lambda x: x.sin().sum(), x, strategy='forward-mode', vectorize=True)
            return y

        B = 5
        x = torch.randn(B, 3)
        with self.assertRaises(RuntimeError):
            vmap(f)(x)

        x = torch.randn([])
        with self.assertRaises(RuntimeError):
            grad(f)(x)


class TestMakeFunctional(TestCase):
    @parametrize('disable_autograd_tracking', [True, False])
