Pytorch coupling #2804
Conversation
firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py
Style, and install script.
Missing blank line.
```
F : pyadjoint.ReducedFunctional
    The reduced functional to wrap.
```
It's rather restrictive to wrap only reduced functionals. Do you plan to extend this to arbitrary Firedrake programs?
Despite the name, which is historic, ReducedFunctional can return any overloaded type. This means that this does enable wrapping of arbitrary Firedrake code. ReducedFunctional is simply the mechanism for expressing a Firedrake calculation as a function of controls, which is what you need if you're going to differentiate through it.
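The idea can be sketched in plain, dependency-free Python: a reduced functional is a callable from controls to an output of any type, paired with an adjoint derivative. This is only a loose illustration of the concept; the class and parameter names below are hypothetical, not pyadjoint's actual implementation (which records operations on a tape rather than taking an explicit derivative callback).

```python
# Hypothetical sketch of the "reduced functional" idea: any computation
# expressed as a function of its controls, together with an adjoint
# derivative, can be differentiated through. Not pyadjoint's real API.

class SketchReducedFunctional:
    """Wraps an arbitrary computation y = g(m) as a function of a control m."""

    def __init__(self, forward, derivative):
        self._forward = forward        # m -> y (any output type, not just scalars)
        self._derivative = derivative  # (m, adj_y) -> adjoint action (dy/dm)^T adj_y

    def __call__(self, m):
        return self._forward(m)

    def derivative(self, m, adj_input):
        return self._derivative(m, adj_input)


# Example: a computation with a non-scalar output, plus its adjoint action.
rf = SketchReducedFunctional(
    forward=lambda m: [m * m, 3.0 * m],
    derivative=lambda m, adj: 2.0 * m * adj[0] + 3.0 * adj[1],
)

print(rf(2.0))                         # forward evaluation: [4.0, 6.0]
print(rf.derivative(2.0, [1.0, 1.0]))  # adjoint action: 2*2*1 + 3*1 = 7.0
```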
I see. So it's possible to pass a function that returns the result of `solve(F == 0, u, bcs=bcs)` to `ReducedFunctional`? I was confused by the test in this PR that uses a function named `solve_poisson` but returns `assemble(u ** 2 * dx)` instead of just returning `u`:
firedrake/tests/regression/test_pytorch_coupling.py (lines 86 to 95 in f05d119):

```python
def solve_poisson(f, V):
    """Solve Poisson problem with homogeneous Dirichlet boundary conditions"""
    u = Function(V)
    v = TestFunction(V)
    F = (inner(grad(u), grad(v)) + inner(u, v) - inner(f, v)) * dx
    bcs = [DirichletBC(V, Constant(1.0), "on_boundary")]
    # Solve PDE
    solve(F == 0, u, bcs=bcs)
    # Assemble Firedrake loss
    return assemble(u ** 2 * dx)
```
Yes, you can. The output doesn't need to be scalar-valued, as illustrated in the other test (`poisson_residual`):

firedrake/tests/regression/test_pytorch_coupling.py (lines 78 to 82 in f05d119):

```python
def poisson_residual(u, f, V):
    """Assemble the residual of a Poisson problem"""
    v = TestFunction(V)
    F = (inner(grad(u), grad(v)) + inner(u, v) - inner(f, v)) * dx
    return assemble(F)
```
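To connect this to PyTorch, such a callable has to be exposed through a custom autograd operator: the forward pass evaluates the wrapped Firedrake computation, and the backward pass propagates the incoming adjoint through its derivative. The sketch below mimics that pattern in plain Python with a toy function standing in for the PDE solve; the real code uses `torch.autograd.Function`, and the class and attribute names here are illustrative assumptions, not the PR's API.

```python
# Dependency-free sketch of the custom-operator pattern: forward evaluates
# the wrapped callable, backward applies its adjoint (vector-Jacobian) action.
# A toy function y = x**3 stands in for the Firedrake computation.

class FiredrakeOpSketch:
    """Mimics the shape of a torch.autograd.Function wrapping Firedrake code."""

    def __init__(self, forward, adjoint):
        self._forward = forward  # e.g. solve_poisson: control -> assembled output
        self._adjoint = adjoint  # adjoint action of the wrapped computation

    def forward(self, x):
        self._saved = x          # save the input for the backward pass
        return self._forward(x)

    def backward(self, grad_out):
        # Chain rule: gradient w.r.t. the input is the adjoint action on grad_out
        return self._adjoint(self._saved, grad_out)


# Toy "PDE": y = x**3, whose adjoint action is grad_out * 3*x**2.
op = FiredrakeOpSketch(lambda x: x ** 3, lambda x, g: g * 3.0 * x ** 2)
y = op.forward(2.0)   # 8.0
g = op.backward(1.0)  # 12.0
```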
This PR enables Firedrake operations to be embedded within PyTorch. It consists of:
This PR is associated with the following pyadjoint PR: dolfin-adjoint/pyadjoint#95