Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch coupling #2804

Merged
merged 59 commits into from
May 18, 2023
Merged

Pytorch coupling #2804

merged 59 commits into from
May 18, 2023

Conversation

nbouziani
Copy link
Contributor

This PR enables to embed Firedrake operations within PyTorch. It consists in:

  • Adding a PyTorch custom operator (analogous to ExternalOperator) to represent Firedrake operators expressed as a ReducedFunctional. Forward and backward computations are delegated to the reduced functional.
  • Adding a backend class to map from PyTorch to Firedrake and vice versa
  • Adding tests

This PR is associated with the following pyadjoint PR: dolfin-adjoint/pyadjoint#95

firedrake/pytorch_coupling/__init__.py Outdated Show resolved Hide resolved
firedrake/pytorch_coupling/pytorch_custom_operator.py Outdated Show resolved Hide resolved
firedrake/pytorch_coupling/pytorch_custom_operator.py Outdated Show resolved Hide resolved
tests/conftest.py Outdated Show resolved Hide resolved
tests/regression/test_pytorch_coupling.py Outdated Show resolved Hide resolved
Copy link
Member

@dham dham left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style, and install script.

firedrake/pytorch_coupling/backends.py Outdated Show resolved Hide resolved
firedrake/pytorch_coupling/backends.py Outdated Show resolved Hide resolved
firedrake/ml/pytorch/backend.py Outdated Show resolved Hide resolved
firedrake/ml/pytorch/pytorch_custom_operator.py Outdated Show resolved Hide resolved
firedrake/ml/pytorch/__init__.py Outdated Show resolved Hide resolved
firedrake/ml/__init__.py Outdated Show resolved Hide resolved
firedrake/ml/pytorch.py Outdated Show resolved Hide resolved
firedrake/ml/pytorch.py Show resolved Hide resolved
firedrake/ml/pytorch.py Show resolved Hide resolved
firedrake/ml/pytorch.py Outdated Show resolved Hide resolved
firedrake/ml/pytorch.py Show resolved Hide resolved
firedrake/ml/pytorch.py Outdated Show resolved Hide resolved
Missing blank line.
tests/conftest.py Outdated Show resolved Hide resolved
@dham dham dismissed stale reviews from ksagiyam and connorjward May 18, 2023 15:24

addressed

@dham dham merged commit f39e129 into master May 18, 2023
@dham dham deleted the pytorch_coupling branch May 18, 2023 15:26
Comment on lines +98 to +99
F : pyadjoint.ReducedFunctional
The reduced functional to wrap.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's rather restrictive to wrap only reduced functionals. Do you plan to extend this to arbitrary Firedrake programs?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Despite the name, which is historic, ReducedFunctional can return any overloaded type. This means that this does enable wrapping of arbitrary Firedrake code. ReducedFunctional is simply the mechanism for expressing a Firedrake calculation as a function of controls, which is what you need if you're going to differentiate through it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. So it's possible to send a function that returns the result of solve(F == 0, u, bcs=bcs) to ReducedFunctional? I was confused with the test in this PR that uses a function named solve_poisson but returns assemble(u ** 2 * dx) instead of just returning u :

def solve_poisson(f, V):
"""Solve Poisson problem with homogeneous Dirichlet boundary conditions"""
u = Function(V)
v = TestFunction(V)
F = (inner(grad(u), grad(v)) + inner(u, v) - inner(f, v)) * dx
bcs = [DirichletBC(V, Constant(1.0), "on_boundary")]
# Solve PDE
solve(F == 0, u, bcs=bcs)
# Assemble Firedrake loss
return assemble(u ** 2 * dx)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you can. The output doesn't need to be scalar-valued as illustrated in the other test (poisson_residual):

def poisson_residual(u, f, V):
"""Assemble the residual of a Poisson problem"""
v = TestFunction(V)
F = (inner(grad(u), grad(v)) + inner(u, v) - inner(f, v)) * dx
return assemble(F)

@nbouziani nbouziani mentioned this pull request Sep 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants