
Pytorch coupling #2804

Merged: 59 commits, May 18, 2023
Changes from 53 commits

Commits
6535f0a
Add backend skeleton
nbouziani Dec 1, 2022
754f5d6
Update backend
nbouziani Feb 1, 2023
f1e37fd
Add PyTorch custom operator
nbouziani Feb 1, 2023
ce9ac31
Add HybridOperator
nbouziani Feb 1, 2023
a72a06e
Add test
nbouziani Feb 1, 2023
cd5f11d
Update backend
nbouziani Feb 1, 2023
f388b88
Fix reduced functional call in FiredrakeTorchOperator
nbouziani Feb 5, 2023
422b1d5
Add torch_op helper function
nbouziani Feb 5, 2023
0288d0a
Rename torch_op -> torch_operator
nbouziani Feb 7, 2023
e01fec5
Remove HybridOperator in favour of torch_operator
nbouziani Feb 10, 2023
22f654c
Update comments FiredrakeTorchOperator
nbouziani Feb 11, 2023
6c2072b
Update pytorch backend mappings
nbouziani Feb 11, 2023
0fa8738
Use torch tensor type
nbouziani Feb 11, 2023
0772a41
Remove squeezing when mapping from pytorch
nbouziani Mar 6, 2023
d5eca47
Fix single-precision casting of AdjFloat
nbouziani Mar 7, 2023
a0e3ea8
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 7, 2023
da5b052
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 7, 2023
e819524
Move pytorch coupling code to preconditioners/
nbouziani Mar 7, 2023
76e7f89
Clean up tests
nbouziani Mar 8, 2023
2dfb9b8
Fix lint and doc
nbouziani Mar 8, 2023
0f511d7
Add torch to test dependencies
nbouziani Mar 8, 2023
ad7a008
Remove spurious torch packages
nbouziani Mar 8, 2023
2e6c230
Remove spurious torch packages
nbouziani Mar 8, 2023
625ef3c
Handle case where PyTorch is not installed using the backend
nbouziani Mar 8, 2023
b57e6b3
Merge branch 'pytorch_coupling' of github.com:firedrakeproject/firedr…
nbouziani Mar 8, 2023
369c070
Clean up
nbouziani Mar 8, 2023
8764522
Add pyadjoint branch
nbouziani Mar 8, 2023
8bde10b
Address PR's comments
nbouziani Mar 9, 2023
d520c28
Update build.yml
nbouziani Mar 9, 2023
89ba376
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 10, 2023
97d1001
Add option to install torch
nbouziani Mar 10, 2023
125cff9
Add doc for torch installation
nbouziani Mar 10, 2023
06ea7f4
Merge master
nbouziani Mar 14, 2023
8d75d6d
Add citation
nbouziani Mar 14, 2023
1dd095c
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 15, 2023
01627ec
Remove non ascii characters
nbouziani Mar 17, 2023
d139c5a
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 17, 2023
cf6dc5a
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 26, 2023
aef4c11
Rename backend function
nbouziani Mar 29, 2023
b19eaae
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 30, 2023
a417a39
Lift torch installation consideration to the import level
nbouziani Apr 1, 2023
1ef3919
Add torch to vanilla docker container
nbouziani Apr 1, 2023
d51158a
Update Dockerfile.vanilla
nbouziani Apr 3, 2023
e5876c4
Build (test) Firedrake docker image with torch
nbouziani Apr 3, 2023
6bf1e2f
Use test container for building docs
nbouziani Apr 3, 2023
a465d93
Update docs.yml
nbouziani Apr 3, 2023
19b7b65
Merge master
nbouziani May 16, 2023
d3448bd
Revert container for docs
nbouziani May 16, 2023
b4d7f48
Fix docs
nbouziani May 16, 2023
9304bf5
Reference class in FiredrakeTorchOperator's docstring
nbouziani May 17, 2023
9fdcd1a
Add torch doc link
nbouziani May 17, 2023
7791fb5
Remove backend abstraction
nbouziani May 18, 2023
4e4d6aa
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani May 18, 2023
ef5b67b
Address comments
nbouziani May 18, 2023
78bdd89
Fix class reference in from_torch docstring
nbouziani May 18, 2023
ec74361
Remove optional in docstrings
nbouziani May 18, 2023
c0a7dcd
Update firedrake/ml/pytorch.py
dham May 18, 2023
8d04288
Address comments (part 2)
nbouziani May 18, 2023
ce8a99b
Fix docs
nbouziani May 18, 2023
2 changes: 1 addition & 1 deletion .github/workflows/build.yml

@@ -48,7 +48,7 @@ jobs:
       - name: Build Firedrake
         run: |
           cd ..
-          ./firedrake/scripts/firedrake-install $COMPLEX --venv-name firedrake_venv --tinyasm --netgen --disable-ssh --minimal-petsc --slepc --documentation-dependencies --install thetis --install gusto --install icepack --install irksome --install femlium --no-package-manager || (cat firedrake-install.log && /bin/false)
+          ./firedrake/scripts/firedrake-install $COMPLEX --venv-name firedrake_venv --tinyasm --torch --netgen --disable-ssh --minimal-petsc --slepc --documentation-dependencies --install thetis --install gusto --install icepack --install irksome --install femlium --no-package-manager || (cat firedrake-install.log && /bin/false)
       - name: Install test dependencies
         run: |
           . ../firedrake_venv/bin/activate
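For context: the only functional change here is the new --torch flag (added by the "Add option to install torch" commit above), which tells firedrake-install to install PyTorch into the Firedrake virtualenv. A minimal local equivalent, with the CI-specific flags omitted, would be:

    ./firedrake/scripts/firedrake-install --torch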
1 change: 1 addition & 0 deletions docs/source/conf.py

@@ -388,6 +388,7 @@
     'pyadjoint': ('https://www.dolfin-adjoint.org/en/latest/', None),
     'numpy': ('https://numpy.org/doc/stable/', None),
     'loopy': ('https://documen.tician.de/loopy/', None),
+    'torch': ('https://pytorch.org/docs/stable/', None),
 }
 
 # -- Options for sphinxcontrib.bibtex ------------------------------------
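For context: this intersphinx entry is what lets docstring cross-references resolve against the PyTorch documentation, so a reference such as

    :class:`torch.autograd.Function`

(used in the FiredrakeTorchOperator docstring below) now links to the corresponding page under https://pytorch.org/docs/stable/.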
1 change: 1 addition & 0 deletions firedrake/ml/__init__.py

@@ -0,0 +1 @@
+from firedrake.ml.pytorch import FiredrakeTorchOperator, torch_operator, from_ml_backend, to_ml_backend # noqa: F401
160 changes: 160 additions & 0 deletions firedrake/ml/pytorch.py

@@ -0,0 +1,160 @@
try:
    import torch
except ImportError:
    raise ImportError("PyTorch is not installed and is required to use the FiredrakeTorchOperator.")

import collections
from functools import partial

from firedrake.function import Function
from firedrake.vector import Vector
from firedrake.constant import Constant
from firedrake_citations import Citations

from pyadjoint.reduced_functional import ReducedFunctional

Citations().add("Bouziani2023", """
@inproceedings{Bouziani2023,
title = {Physics-driven machine learning models coupling {PyTorch} and {Firedrake}},
author = {Bouziani, Nacime and Ham, David A.},
booktitle = {{ICLR} 2023 {Workshop} on {Physics} for {Machine} {Learning}},
year = {2023},
doi = {10.48550/arXiv.2303.06871}
}
""")


class FiredrakeTorchOperator(torch.autograd.Function):
    """
    PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F.

    `FiredrakeTorchOperator` is a wrapper around :class:`torch.autograd.Function` that executes forward and
    backward passes by directly calling the reduced functional F.

    Inputs:
        metadata: dictionary used to stash Firedrake objects.
        x_P: PyTorch tensors representing the inputs to the Firedrake operator F

    Outputs:
        y_P: PyTorch tensor representing the output of the Firedrake operator F
    """

    def __init__(self):
        super(FiredrakeTorchOperator, self).__init__()

    # This method is wrapped by something cancelling annotation (probably 'with torch.no_grad()')
    @staticmethod
    def forward(ctx, metadata, *x_P):
        """Forward pass of the PyTorch custom operator."""
        F = metadata['F']
        V = metadata['V_controls']
        # Convert PyTorch input (i.e. controls) to Firedrake
        x_F = [from_ml_backend(xi, Vi) for xi, Vi in zip(x_P, V)]
        # Forward operator: delegated to pyadjoint.ReducedFunctional, which recomputes the blocks on the tape
        y_F = F(x_F)
        # Stash metadata to the PyTorch context
        ctx.metadata.update(metadata)
        # Convert Firedrake output to PyTorch
        y_P = to_ml_backend(y_F)
        return y_P.detach()

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass of the PyTorch custom operator."""
        F = ctx.metadata['F']
        V = ctx.metadata['V_output']
        # Convert PyTorch gradient to Firedrake
        adj_input = from_ml_backend(grad_output, V)
        if isinstance(adj_input, Function):
            adj_input = adj_input.vector()

        # Compute the adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
        adj_output = F.derivative(adj_input=adj_input)

        # Tuplify the adjoint output
        adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output

        # `None` is for the metadata argument in `forward`
        return None, *[to_ml_backend(di) for di in adj_output]


def torch_operator(F):
    """Operator that converts a pyadjoint.ReducedFunctional into a firedrake.FiredrakeTorchOperator
    whose inputs and outputs are PyTorch tensors.
    """
    Citations().register("Bouziani2023")

    if not isinstance(F, ReducedFunctional):
        raise ValueError("F must be a ReducedFunctional")

    V_output = extract_function_space(F.functional)
    V_controls = [c.control.function_space() for c in F.controls]
    metadata = {'F': F, 'V_controls': V_controls, 'V_output': V_output}
    F_P = partial(FiredrakeTorchOperator.apply, metadata)
    return F_P


def extract_function_space(x):
    """Get function space out of x"""
    if isinstance(x, Function):
        return x.function_space()
    elif isinstance(x, Vector):
        return extract_function_space(x.function)
    elif isinstance(x, float):
        return None
    else:
        raise ValueError("Cannot infer the function space of %s" % x)


def to_ml_backend(x, gather=False, batched=True, **kwargs):
    r"""Convert a Firedrake object `x` into a PyTorch tensor.

    :arg x: Firedrake object (Function, Vector, Constant)
    :kwarg gather: if True, gather data from all processes
    :kwarg batched: if True, add a batch dimension to the tensor
    :kwarg kwargs: additional arguments to be passed to the torch.Tensor constructor
        - device: device on which the tensor is allocated (default: "cpu")
        - dtype: the desired data type of returned tensor (default: type of x.dat.data)
        - requires_grad: if the tensor should be annotated (default: False)
    """
    if isinstance(x, (Function, Vector)):
        if gather:
            # Gather data from all processes
            x_P = torch.tensor(x.vector().gather(), **kwargs)
        else:
            # Use local data
            x_P = torch.tensor(x.vector().get_local(), **kwargs)
        if batched:
            # Default behaviour: add batch dimension after converting to PyTorch
            return x_P[None, :]
        return x_P
    elif isinstance(x, Constant):
        return torch.tensor(x.values(), **kwargs)
    elif isinstance(x, (float, int)):
        if isinstance(x, float):
            # Set double-precision
            kwargs['dtype'] = torch.double
        return torch.tensor(x, **kwargs)
    else:
        raise ValueError("Cannot convert %s to a torch tensor" % str(type(x)))


def from_ml_backend(x, V=None):
    r"""Convert a PyTorch tensor `x` into a Firedrake object.

    :arg x: PyTorch tensor (torch.Tensor)
    :kwarg V: function space of the corresponding Function, or None when `x` is to be mapped to a Constant
    """
    if x.device.type != "cpu":
        raise NotImplementedError("Firedrake does not support GPU tensors")

    if V is None:
        val = x.detach().numpy()
        if val.shape == (1,):
            val = val[0]
        return Constant(val)
    else:
        x = x.detach().numpy()
        x_F = Function(V, dtype=x.dtype)
        x_F.vector().set_local(x)
        return x_F
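To make the new API concrete, here is a minimal usage sketch; it is not part of the diff. The mesh, function space and functional are illustrative choices, and the firedrake_adjoint import reflects the adjoint workflow at the time of this PR; only torch_operator, to_ml_backend and FiredrakeTorchOperator come from this file.

from firedrake import UnitSquareMesh, FunctionSpace, Function, assemble, dx
from firedrake_adjoint import ReducedFunctional, Control  # importing this starts taping
from firedrake.ml import torch_operator, to_ml_backend

mesh = UnitSquareMesh(8, 8)
V = FunctionSpace(mesh, "CG", 1)
u = Function(V)

# Record J(u) = assemble(u**2 * dx) on the pyadjoint tape and wrap it
J = assemble(u * u * dx)
rf = ReducedFunctional(J, Control(u))

# G is the callable returned by torch_operator, i.e. partial(FiredrakeTorchOperator.apply, metadata)
G = torch_operator(rf)

# Convert the control to a torch tensor; batched=False keeps the flat dof vector
# that from_ml_backend maps back onto V during the forward pass
x_P = to_ml_backend(u, requires_grad=True, batched=False)

y_P = G(x_P)    # forward: recomputes the tape via rf(x_F)
y_P.backward()  # backward: calls rf.derivative() with the incoming gradient
assert x_P.grad is not None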
14 changes: 14 additions & 0 deletions tests/conftest.py

@@ -14,11 +14,21 @@ def pytest_configure(config):
     config.addinivalue_line(
         "markers",
         "skipcomplexnoslate: mark as skipped in complex mode due to lack of Slate")
+    config.addinivalue_line(
+        "markers",
+        "skiptorch: mark as skipped if PyTorch is not installed")


 def pytest_collection_modifyitems(session, config, items):
     from firedrake.utils import complex_mode, SLATE_SUPPORTS_COMPLEX
 
+    try:
+        import firedrake.ml as fd_ml
+        del fd_ml
+        ml_backend = True
+    except ImportError:
+        ml_backend = False
+
     for item in items:
         if complex_mode:
             if item.get_closest_marker("skipcomplex") is not None:
@@ -29,6 +39,10 @@ def pytest_collection_modifyitems(session, config, items):
             if item.get_closest_marker("skipreal") is not None:
                 item.add_marker(pytest.mark.skip(reason="Test makes no sense unless in complex mode"))
 
+        if not ml_backend:
+            if item.get_closest_marker("skiptorch") is not None:
+                item.add_marker(pytest.mark.skip(reason="Test makes no sense if PyTorch is not installed"))
 

 @pytest.fixture(scope="module", autouse=True)
 def check_empty_tape(request):
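A hypothetical test exercising the new marker (the test name and body are illustrative, not from this PR); it is skipped automatically whenever firedrake.ml cannot be imported:

import pytest


@pytest.mark.skiptorch
def test_torch_roundtrip():
    from firedrake import UnitIntervalMesh, FunctionSpace, Function
    from firedrake.ml import to_ml_backend, from_ml_backend

    V = FunctionSpace(UnitIntervalMesh(4), "CG", 1)
    u = Function(V).assign(1.0)
    # Round-trip: Firedrake Function -> torch tensor -> Firedrake Function
    x_P = to_ml_backend(u, batched=False)
    u2 = from_ml_backend(x_P, V)
    assert (u2.dat.data_ro == u.dat.data_ro).all()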