Disable calling Tensor.requires_grad_() inside a functorch transform #849

Merged (1 commit) on Jun 3, 2022
functorch/_src/eager_transforms.py (8 changes: 7 additions & 1 deletion)
@@ -32,6 +32,7 @@
     _func_increment_nesting,
     _assert_wrapped_functional,
     _propagate_functional_input_mutation,
+    set_inplace_requires_grad_allowed,
 )
 
 argnums_t = Union[int, Tuple[int, ...]]
@@ -40,7 +41,12 @@
 def _create_differentiable(inps, level=None):
     def create_differentiable(x):
         if isinstance(x, torch.Tensor):
-            return x.requires_grad_()
+            try:
+                set_inplace_requires_grad_allowed(True)
+                return x.requires_grad_()
+            finally:
+                set_inplace_requires_grad_allowed(False)
+
         raise ValueError(f'Thing passed to transform API must be Tensor, '
                          f'got {type(x)}')
     return tree_map(create_differentiable, inps)
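Note: the try/finally pattern above can equivalently be wrapped in a context manager. The sketch below is not part of this PR; it assumes the new binding is importable as functorch._C.set_inplace_requires_grad_allowed, and the helper name enable_inplace_requires_grad is hypothetical.

import contextlib

import torch
from functorch._C import set_inplace_requires_grad_allowed  # assumed import path


@contextlib.contextmanager
def enable_inplace_requires_grad():
    # Hypothetical helper: temporarily allow in-place requires_grad_() calls
    # while functorch itself needs them, then restore the disallowed default.
    set_inplace_requires_grad_allowed(True)
    try:
        yield
    finally:
        set_inplace_requires_grad_allowed(False)


def make_differentiable(x: torch.Tensor) -> torch.Tensor:
    # Roughly what _create_differentiable's inner function does for one tensor.
    with enable_inplace_requires_grad():
        return x.requires_grad_()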
functorch/csrc/DynamicLayer.cpp (23 changes: 21 additions & 2 deletions)
@@ -102,13 +102,26 @@ class FuncTorchTLS : public FuncTorchTLSBase {
   }
 
   void checkSupportsInplaceRequiresGrad() const override {
-    // Does nothing
+    TORCH_CHECK(dynamicLayerStack.size() == 0 || allow_inplace_requires_grad_,
+        "You are attempting to call Tensor.requires_grad_() (or perhaps using ",
+        "torch.autograd.functional.* APIs) inside of a function being transformed ",
+        "by a functorch transform. ",
+        "This is unsupported, please attempt to use the functorch transforms ",
+        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() "
+        "outside of a function being transformed instead.");
   }
   void checkSupportsRetainGrad() const override {
-    // Does nothing
+    TORCH_CHECK(dynamicLayerStack.size() == 0,
+        "You are attempting to call Tensor.retain_grad() ",
+        "inside of a function being transformed ",
+        "by a functorch transform. ",
+        "This is unsupported, please attempt to use the functorch transforms ",
+        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call retain_grad() "
+        "outside of a function being transformed instead.");
   }
 
   std::vector<DynamicLayer> dynamicLayerStack;
+  bool allow_inplace_requires_grad_ = false;
 };
 
 static FuncTorchTLS* getRawFunctorchTLS() {
@@ -122,6 +135,12 @@ static FuncTorchTLS* getRawFunctorchTLS() {
   return result;
 }
 
+void setInplaceRequiresGradAllowed(bool allowed) {
+  auto* functorch_tls = getRawFunctorchTLS();
+  functorch_tls->allow_inplace_requires_grad_ = allowed;
+}
+
+
 static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
   return getRawFunctorchTLS()->dynamicLayerStack;
 }
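Note: in user-facing terms, the new checks mean Tensor.requires_grad_() still works in plain eager code (the dynamic layer stack is empty there), but raises once a functorch transform is active and functorch has not raised the allow_inplace_requires_grad_ flag. A minimal illustration, mirroring test_requires_grad_inside_transform added below:

import torch
from functorch import grad

x = torch.randn(3)
x.requires_grad_()  # fine: no functorch transform is on the stack


def f(y):
    y.requires_grad_()  # runs under grad(), so checkSupportsInplaceRequiresGrad trips
    return y.sin().sum()


try:
    grad(f)(torch.randn(3))
except RuntimeError as err:
    print(err)  # "You are attempting to call Tensor.requires_grad_() ..."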
functorch/csrc/DynamicLayer.h (2 changes: 2 additions & 0 deletions)
@@ -85,6 +85,8 @@ Tensor unwrapIfDead(const Tensor& tensor);
 std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
 std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
 
+void setInplaceRequiresGradAllowed(bool allowed);
+
 
 }
 } // namespace at
functorch/csrc/init.cpp (1 change: 1 addition & 0 deletions)
@@ -379,6 +379,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("_set_vmap_fallback_warning_enabled", &at::functorch::setVmapFallbackWarningEnabled, "Set vmap fallback warnings");
   m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
   m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
+  m.def("set_inplace_requires_grad_allowed", &at::functorch::setInplaceRequiresGradAllowed);
   m.def("dlevel", &at::functorch::dlevel, "dlevel");
   m.def("dump_tensor", &at::functorch::dump_tensor, "dump_tensor");
   m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
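Note: unlike the vmap-fallback bindings above, this one is registered without a leading underscore, but it is still internal plumbing that user code should not normally call. A quick smoke test of the binding, assuming the extension module is importable as functorch._C:

import functorch._C as _C  # extension module built by init.cpp (assumed import path)

# The binding takes a single bool and returns None; it only toggles the
# thread-local allow_inplace_requires_grad_ flag inside FuncTorchTLS.
_C.set_inplace_requires_grad_allowed(True)
_C.set_inplace_requires_grad_allowed(False)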
test/test_eager_transforms.py (87 changes: 87 additions & 0 deletions)
@@ -2157,6 +2157,93 @@ def f(x):
         new_cotangent = torch.randn(())
         self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
 
+    def test_requires_grad_inside_transform(self, device):
+        def f(x):
+            x.requires_grad_()
+            return x.sin().sum()
+
+        x = torch.randn(3)
+
+        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
+            vmap(f)(x)
+        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
+            grad(f)(x)
+        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
+            vmap(grad(f))(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
+            grad(grad(f))(x)
+
+    def test_retain_grad_inside_transform(self, device):
+        def f(x):
+            y = x.sin()
+            y.retain_grad()
+            return y.sum()
+
+        x = torch.randn(3)
+
+        with self.assertRaisesRegex(RuntimeError, "Tensor.retain_grad()"):
+            grad(f)(x)
+
+    def test_autograd_functional_jacrev_inside_transform(self, device):
+        def f(x):
+            y = torch.autograd.functional.jacobian(lambda x: x.sin().sum(), x)
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            grad(f)(x)
+
+    def test_autograd_functional_vjp_inside_transform(self, device):
+        def f(x):
+            y = torch.autograd.functional.vjp(lambda x: x.sin().sum(), x)
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            grad(f)(x)
+
+    def test_autograd_functional_jvp_inside_transform(self, device):
+        def f(x):
+            t = torch.ones_like(x)
+            y = torch.autograd.functional.jvp(lambda x: x.sin().sum(), (x,), (t,))
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            grad(f)(x)
+
+    def test_autograd_functional_jacfwd_inside_transform(self, device):
+        def f(x):
+            y = torch.autograd.functional.jacobian(
+                lambda x: x.sin().sum(), x, strategy='forward-mode', vectorize=True)
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaises(RuntimeError):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaises(RuntimeError):
+            grad(f)(x)
+
 
 class TestMakeFunctional(TestCase):
     @parametrize('disable_autograd_tracking', [True, False])
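Note: as the new error messages suggest, the supported replacement for the patterns these tests reject is to use the functorch transforms directly instead of torch.autograd.functional.* inside vmap or grad. A rough sketch of the jacobian-under-vmap case from test_autograd_functional_jacrev_inside_transform, rewritten with functorch's jacrev (not part of this PR):

import torch
from functorch import jacrev, vmap

B = 5
x = torch.randn(B, 3)

# Per-sample Jacobian of x.sin().sum(): jacrev builds the Jacobian of the
# scalar function, and vmap maps it over the batch dimension.
per_sample_jac = vmap(jacrev(lambda xi: xi.sin().sum()))(x)
print(per_sample_jac.shape)  # torch.Size([5, 3])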