Pytorch coupling #2804

Merged

merged 59 commits into master from pytorch_coupling on May 18, 2023

Changes from 22 commits (59 commits in total)

Commits
6535f0a
Add backend skeleton
nbouziani Dec 1, 2022
754f5d6
Update backend
nbouziani Feb 1, 2023
f1e37fd
Add PyTorch custom operator
nbouziani Feb 1, 2023
ce9ac31
Add HybridOperator
nbouziani Feb 1, 2023
a72a06e
Add test
nbouziani Feb 1, 2023
cd5f11d
Update backend
nbouziani Feb 1, 2023
f388b88
Fix reduced functional call in FiredrakeTorchOperator
nbouziani Feb 5, 2023
422b1d5
Add torch_op helper function
nbouziani Feb 5, 2023
0288d0a
Rename torch_op -> torch_operator
nbouziani Feb 7, 2023
e01fec5
Remove HybridOperator in favour of torch_operator
nbouziani Feb 10, 2023
22f654c
Update comments FiredrakeTorchOperator
nbouziani Feb 11, 2023
6c2072b
Update pytorch backend mappings
nbouziani Feb 11, 2023
0fa8738
Use torch tensor type
nbouziani Feb 11, 2023
0772a41
Remove squeezing when mapping from pytorch
nbouziani Mar 6, 2023
d5eca47
Fix single-precision casting of AdjFloat
nbouziani Mar 7, 2023
a0e3ea8
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 7, 2023
da5b052
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 7, 2023
e819524
Move pytorch coupling code to preconditioners/
nbouziani Mar 7, 2023
76e7f89
Clean up tests
nbouziani Mar 8, 2023
2dfb9b8
Fix lint and doc
nbouziani Mar 8, 2023
0f511d7
Add torch to test dependencies
nbouziani Mar 8, 2023
ad7a008
Remove spurious torch packages
nbouziani Mar 8, 2023
2e6c230
Remove spurious torch packages
nbouziani Mar 8, 2023
625ef3c
Handle case where PyTorch is not installed using the backend
nbouziani Mar 8, 2023
b57e6b3
Merge branch 'pytorch_coupling' of github.com:firedrakeproject/firedr…
nbouziani Mar 8, 2023
369c070
Clean up
nbouziani Mar 8, 2023
8764522
Add pyadjoint branch
nbouziani Mar 8, 2023
8bde10b
Address PR's comments
nbouziani Mar 9, 2023
d520c28
Update build.yml
nbouziani Mar 9, 2023
89ba376
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 10, 2023
97d1001
Add option to install torch
nbouziani Mar 10, 2023
125cff9
Add doc for torch installation
nbouziani Mar 10, 2023
06ea7f4
Merge master
nbouziani Mar 14, 2023
8d75d6d
Add citation
nbouziani Mar 14, 2023
1dd095c
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 15, 2023
01627ec
Remove non ascii characters
nbouziani Mar 17, 2023
d139c5a
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 17, 2023
cf6dc5a
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 26, 2023
aef4c11
Rename backend function
nbouziani Mar 29, 2023
b19eaae
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 30, 2023
a417a39
Lift torch installation consideration to the import level
nbouziani Apr 1, 2023
1ef3919
Add torch to vanilla docker container
nbouziani Apr 1, 2023
d51158a
Update Dockerfile.vanilla
nbouziani Apr 3, 2023
e5876c4
Build (test) Firedrake docker image with torch
nbouziani Apr 3, 2023
6bf1e2f
Use test container for building docs
nbouziani Apr 3, 2023
a465d93
Update docs.yml
nbouziani Apr 3, 2023
19b7b65
Merge master
nbouziani May 16, 2023
d3448bd
Revert container for docs
nbouziani May 16, 2023
b4d7f48
Fix docs
nbouziani May 16, 2023
9304bf5
Reference class in FiredrakeTorchOperator's docstring
nbouziani May 17, 2023
9fdcd1a
Add torch doc link
nbouziani May 17, 2023
7791fb5
Remove backend abstraction
nbouziani May 18, 2023
4e4d6aa
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani May 18, 2023
ef5b67b
Address comments
nbouziani May 18, 2023
78bdd89
Fix class reference in from_torch docstring
nbouziani May 18, 2023
ec74361
Remove optional in docstrings
nbouziani May 18, 2023
c0a7dcd
Update firedrake/ml/pytorch.py
dham May 18, 2023
8d04288
Address comments (part 2)
nbouziani May 18, 2023
ce8a99b
Fix docs
nbouziani May 18, 2023
1 change: 1 addition & 0 deletions .github/workflows/build.yml
@@ -55,6 +55,7 @@ jobs:
python $(which firedrake-clean)
python -m pip install pytest-cov pytest-timeout pytest-xdist pytest-timeout
python -m pip list
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test Firedrake
run: |
. ../build/bin/activate
1 change: 1 addition & 0 deletions firedrake/preconditioners/__init__.py
@@ -11,3 +11,4 @@
from firedrake.preconditioners.hypre_ads import * # noqa: F401
from firedrake.preconditioners.fdm import * # noqa: F401
from firedrake.preconditioners.facet_split import * # noqa: F401
from firedrake.preconditioners.pytorch_coupling import * # noqa: F401
2 changes: 2 additions & 0 deletions firedrake/preconditioners/pytorch_coupling/__init__.py
@@ -0,0 +1,2 @@
from .backends import get_backend # noqa: F401
from .pytorch_custom_operator import torch_operator # noqa: F401
105 changes: 105 additions & 0 deletions firedrake/preconditioners/pytorch_coupling/backends.py
@@ -0,0 +1,105 @@
import numpy as np

from firedrake.function import Function
from firedrake.vector import Vector
from firedrake.constant import Constant

import firedrake.utils as utils


class AbstractMLBackend(object):

def backend(self):
raise NotImplementedError

def to_ml_backend(self, x):
"""Convert from Firedrake to ML backend
x: Firedrake object
"""
raise NotImplementedError

def from_ml_backend(self, x, V):
"""Convert from ML backend to Firedrake
x: ML backend object
"""
raise NotImplementedError

def get_function_space(self, x):
"""Get function space out of x"""
if isinstance(x, Function):
return x.function_space()
elif isinstance(x, Vector):
return self.get_function_space(x.function)
elif isinstance(x, float):
return None
else:
raise ValueError("Cannot infer the function space of %s" % x)


class PytorchBackend(AbstractMLBackend):

@utils.cached_property
def backend(self):
try:
import torch
except ImportError:
raise ImportError("Error when trying to import PyTorch")
return torch

@utils.cached_property
def custom_operator(self):
from firedrake.preconditioners.pytorch_coupling.pytorch_custom_operator import FiredrakeTorchOperator
return FiredrakeTorchOperator().apply

def to_ml_backend(self, x, gather=False, batched=True, **kwargs):
""" Convert a Firedrake object `x` into a PyTorch tensor

x: Firedrake object (Function, Vector, Constant)
gather: if True, gather data from all processes
batched: if True, add a batch dimension to the tensor
kwargs: additional arguments to be passed to torch.Tensor constructor
- device: device on which the tensor is allocated (default: "cpu")
- dtype: the desired data type of returned tensor (default: type of x.dat.data)
- requires_grad: if the tensor should be annotated (default: False)
"""
if isinstance(x, (Function, Vector)):
# State counter: get_local makes a copy and increases the state counter, while gather does not.
# We probably always want to increase the state counter, so the gather case should eventually do the same.
if gather:
# Gather data from all processes
x_P = self.backend.tensor(x.vector().gather(), **kwargs)
else:
# Use local data
x_P = self.backend.tensor(x.vector().get_local(), **kwargs)
if batched:
# Default behaviour: add batch dimension after converting to PyTorch
return x_P[None, :]
return x_P
elif isinstance(x, Constant):
return self.backend.tensor(x.values(), **kwargs)
elif isinstance(x, (float, int)):
return self.backend.tensor(np.array(x), **kwargs)
else:
raise ValueError("Cannot convert %s to a torch tensor" % str(type(x)))

def from_ml_backend(self, x, V=None, gather=False):
if x.device.type != "cpu":
raise NotImplementedError("Firedrake does not support GPU tensors")

if V is None:
val = x.detach().numpy()
if val.shape == (1,):
val = val[0]
return Constant(val)
else:
x = x.detach().numpy()
x_F = Function(V, dtype=x.dtype)
x_F.vector().set_local(x)
return x_F


def get_backend(backend_name='pytorch'):
if backend_name == 'pytorch':
return PytorchBackend()
else:
raise NotImplementedError("The backend: %s is not supported." % backend_name)
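
As a quick illustration of the conversion API defined above (not part of the diff), here is a minimal round-trip between a Firedrake Function and a PyTorch tensor. It assumes Firedrake and torch are installed and uses the import path as it stands on this branch; the mesh, function space and expression are arbitrary choices for the sketch.

from firedrake import *
from firedrake.preconditioners.pytorch_coupling import get_backend

backend = get_backend("pytorch")

mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)
x, y = SpatialCoordinate(mesh)
u = Function(V).interpolate(sin(pi * x) * sin(pi * y))

# Firedrake -> PyTorch: batched by default, extra kwargs are forwarded to torch.tensor
u_t = backend.to_ml_backend(u, requires_grad=True)  # shape (1, V.dim())

# PyTorch -> Firedrake: drop the batch dimension before writing into the local vector
u_back = backend.from_ml_backend(u_t.squeeze(0), V)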
74 changes: 74 additions & 0 deletions firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py
@@ -0,0 +1,74 @@
import collections
from functools import partial

from firedrake.preconditioners.pytorch_coupling import get_backend
from firedrake.function import Function

from pyadjoint.reduced_functional import ReducedFunctional


backend = get_backend('pytorch')


class FiredrakeTorchOperator(backend.backend.autograd.Function):
"""
PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F.
`FiredrakeTorchOperator` is a wrapper around `torch.autograd.Function` that executes forward and backward
passes by directly calling the reduced functional F.

Inputs:
metadata: dictionary used to stash Firedrake objects.
*ω: PyTorch tensors representing the inputs to the Firedrake operator F

Outputs:
y: PyTorch tensor representing the output of the Firedrake operator F
"""

# This method is wrapped by something cancelling annotation (probably 'with torch.no_grad()')
@staticmethod
def forward(ctx, metadata, *ω):
"""Forward pass of the PyTorch custom operator."""
F = metadata['F']
V = metadata['V_controls']
# Convert PyTorch input (i.e. controls) to Firedrake
ω_F = [backend.from_ml_backend(ωi, Vi) for ωi, Vi in zip(ω, V)]
# Forward operator: delegated to pyadjoint.ReducedFunctional which recomputes the blocks on the tape
y_F = F(ω_F)
# Stash metadata to the PyTorch context
ctx.metadata.update(metadata)
# Convert Firedrake output to PyTorch
y = backend.to_ml_backend(y_F)
return y.detach()

@staticmethod
def backward(ctx, grad_output):
"""Backward pass of the PyTorch custom operator."""
F = ctx.metadata['F']
V = ctx.metadata['V_output']
# Convert PyTorch gradient to Firedrake
adj_input = backend.from_ml_backend(grad_output, V)
if isinstance(adj_input, Function):
adj_input = adj_input.vector()

# Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
Δω = F.derivative(adj_input=adj_input)

# Tuplify adjoint output
Δω = (Δω,) if not isinstance(Δω, collections.abc.Sequence) else Δω

# None is for metadata arg in `forward`
return None, *[backend.to_ml_backend(Δωi) for Δωi in Δω]


def torch_operator(F):
"""Operator that converts a pyadjoint.ReducedFunctional into a firedrake.FiredrakeTorchOperator
whose inputs and outputs are PyTorch tensors.
"""
if not isinstance(F, ReducedFunctional):
raise ValueError("F must be a ReducedFunctional")

V_output = backend.get_function_space(F.functional)
V_controls = [c.control.function_space() for c in F.controls]
metadata = {'F': F, 'V_controls': V_controls, 'V_output': V_output}
φ = partial(backend.custom_operator, metadata)
return φ
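
For a sense of how these pieces fit together (an illustrative sketch, not part of the diff), the intended workflow is to record a Firedrake computation on the pyadjoint tape, wrap the resulting ReducedFunctional with torch_operator, and drive it from PyTorch. The functional and mesh below are arbitrary, and the adjoint import and annotation setup may differ between Firedrake versions.

import torch
from firedrake import *
from firedrake.adjoint import *  # depending on the Firedrake version this may be
                                 # `from firedrake_adjoint import *` and/or require continue_annotation()
from firedrake.preconditioners.pytorch_coupling import torch_operator

mesh = UnitIntervalMesh(8)
V = FunctionSpace(mesh, "CG", 1)

# Record a simple Firedrake computation on the pyadjoint tape
u = Function(V)
J = assemble(u * u * dx)
Jhat = ReducedFunctional(J, Control(u))

# Wrap the ReducedFunctional as a PyTorch-callable operator
G = torch_operator(Jhat)

# Forward pass: the tensor is converted to a Firedrake Function and the tape is replayed
theta = torch.rand(V.dim(), dtype=torch.float64, requires_grad=True)
y = G(theta)

# A scalar loss built from `y` can then call .backward(); the backward pass is
# delegated to Jhat.derivative (see FiredrakeTorchOperator.backward above).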