Pytorch coupling #2804

Merged
59 commits merged on May 18, 2023
Changes from 27 commits
Commits
59 commits
6535f0a
Add backend skeleton
nbouziani Dec 1, 2022
754f5d6
Update backend
nbouziani Feb 1, 2023
f1e37fd
Add PyTorch custom operator
nbouziani Feb 1, 2023
ce9ac31
Add HybridOperator
nbouziani Feb 1, 2023
a72a06e
Add test
nbouziani Feb 1, 2023
cd5f11d
Update backend
nbouziani Feb 1, 2023
f388b88
Fix reduced functional call in FiredrakeTorchOperator
nbouziani Feb 5, 2023
422b1d5
Add torch_op helper function
nbouziani Feb 5, 2023
0288d0a
Rename torch_op -> torch_operator
nbouziani Feb 7, 2023
e01fec5
Remove HybridOperator in favour of torch_operator
nbouziani Feb 10, 2023
22f654c
Update comments FiredrakeTorchOperator
nbouziani Feb 11, 2023
6c2072b
Update pytorch backend mappings
nbouziani Feb 11, 2023
0fa8738
Use torch tensor type
nbouziani Feb 11, 2023
0772a41
Remove squeezing when mapping from pytorch
nbouziani Mar 6, 2023
d5eca47
Fix single-precision casting of AdjFloat
nbouziani Mar 7, 2023
a0e3ea8
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 7, 2023
da5b052
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 7, 2023
e819524
Move pytorch coupling code to preconditioners/
nbouziani Mar 7, 2023
76e7f89
Clean up tests
nbouziani Mar 8, 2023
2dfb9b8
Fix lint and doc
nbouziani Mar 8, 2023
0f511d7
Add torch to test dependencies
nbouziani Mar 8, 2023
ad7a008
Remove spurious torch packages
nbouziani Mar 8, 2023
2e6c230
Remove spurious torch packages
nbouziani Mar 8, 2023
625ef3c
Handle case where PyTorch is not installed using the backend
nbouziani Mar 8, 2023
b57e6b3
Merge branch 'pytorch_coupling' of github.com:firedrakeproject/firedr…
nbouziani Mar 8, 2023
369c070
Clean up
nbouziani Mar 8, 2023
8764522
Add pyadjoint branch
nbouziani Mar 8, 2023
8bde10b
Address PR's comments
nbouziani Mar 9, 2023
d520c28
Update build.yml
nbouziani Mar 9, 2023
89ba376
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 10, 2023
97d1001
Add option to install torch
nbouziani Mar 10, 2023
125cff9
Add doc for torch installation
nbouziani Mar 10, 2023
06ea7f4
Merge master
nbouziani Mar 14, 2023
8d75d6d
Add citation
nbouziani Mar 14, 2023
1dd095c
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 15, 2023
01627ec
Remove non ascii characters
nbouziani Mar 17, 2023
d139c5a
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 17, 2023
cf6dc5a
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 26, 2023
aef4c11
Rename backend function
nbouziani Mar 29, 2023
b19eaae
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani Mar 30, 2023
a417a39
Lift torch installation consideration to the import level
nbouziani Apr 1, 2023
1ef3919
Add torch to vanilla docker container
nbouziani Apr 1, 2023
d51158a
Update Dockerfile.vanilla
nbouziani Apr 3, 2023
e5876c4
Build (test) Firedrake docker image with torch
nbouziani Apr 3, 2023
6bf1e2f
Use test container for building docs
nbouziani Apr 3, 2023
a465d93
Update docs.yml
nbouziani Apr 3, 2023
19b7b65
Merge master
nbouziani May 16, 2023
d3448bd
Revert container for docs
nbouziani May 16, 2023
b4d7f48
Fix docs
nbouziani May 16, 2023
9304bf5
Reference class in FiredrakeTorchOperator's docstring
nbouziani May 17, 2023
9fdcd1a
Add torch doc link
nbouziani May 17, 2023
7791fb5
Remove backend abstraction
nbouziani May 18, 2023
4e4d6aa
Merge remote-tracking branch 'origin/master' into pytorch_coupling
nbouziani May 18, 2023
ef5b67b
Address comments
nbouziani May 18, 2023
78bdd89
Fix class reference in from_torch docstring
nbouziani May 18, 2023
ec74361
Remove optional in docstrings
nbouziani May 18, 2023
c0a7dcd
Update firedrake/ml/pytorch.py
dham May 18, 2023
8d04288
Address comments (part 2)
nbouziani May 18, 2023
ce8a99b
Fix docs
nbouziani May 18, 2023
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
@@ -48,13 +48,14 @@ jobs:
- name: Build Firedrake
run: |
cd ..
./firedrake/scripts/firedrake-install $COMPLEX --venv-name build --tinyasm --disable-ssh --minimal-petsc --slepc --documentation-dependencies --install thetis --install gusto --install icepack --install irksome --install femlium --no-package-manager || (cat firedrake-install.log && /bin/false)
./firedrake/scripts/firedrake-install $COMPLEX --venv-name build --tinyasm --disable-ssh --minimal-petsc --slepc --documentation-dependencies --install thetis --install gusto --install icepack --install irksome --install femlium --no-package-manager --package-branch pyadjoint adjoint-1-forms || (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
. ../build/bin/activate
python $(which firedrake-clean)
python -m pip install pytest-cov pytest-timeout pytest-xdist pytest-timeout
python -m pip list
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test Firedrake
run: |
. ../build/bin/activate
1 change: 1 addition & 0 deletions firedrake/__init__.py
@@ -103,6 +103,7 @@
from firedrake.ensemble import *
from firedrake.randomfunctiongen import *
from firedrake.progress_bar import ProgressBar # noqa: F401
from firedrake.pytorch_coupling import *

from firedrake.logging import *
# Set default log level
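With this wildcard import in place, the new coupling utilities are re-exported at the top level, so the following should work (a sketch, assuming the public names are picked up by the star import):

from firedrake import get_backend, torch_operator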
2 changes: 2 additions & 0 deletions firedrake/pytorch_coupling/__init__.py
@@ -0,0 +1,2 @@
from firedrake.pytorch_coupling.backends import get_backend # noqa: F401
from firedrake.pytorch_coupling.pytorch_custom_operator import torch_operator # noqa: F401
120 changes: 120 additions & 0 deletions firedrake/pytorch_coupling/backends.py
@@ -0,0 +1,120 @@
from firedrake.function import Function
from firedrake.vector import Vector
from firedrake.constant import Constant

import firedrake.utils as utils


class AbstractMLBackend(object):

def backend(self):
raise NotImplementedError

def to_ml_backend(self, x):
r"""Convert from Firedrake to ML backend.

:arg x: Firedrake object
"""
raise NotImplementedError

def from_ml_backend(self, x, V):
r"""Convert from ML backend to Firedrake.

:arg x: ML backend object
:arg V: function space of the corresponding Firedrake object, or None when `x` is to be mapped to a Constant
"""
raise NotImplementedError

def get_function_space(self, x):
"""Get function space out of x"""
if isinstance(x, Function):
return x.function_space()
elif isinstance(x, Vector):
return self.get_function_space(x.function)
elif isinstance(x, float):
return None
else:
raise ValueError("Cannot infer the function space of %s" % x)


class PytorchBackend(AbstractMLBackend):

def __init__(self):
try:
import torch
self._backend = torch
except ImportError:
self._backend = None

@property
def backend(self):
if self:
return self._backend
else:
raise ImportError("Error when trying to import PyTorch.")

def __bool__(self):
return self._backend is not None

@utils.cached_property
def custom_operator(self):
from firedrake.pytorch_coupling.pytorch_custom_operator import FiredrakeTorchOperator
return FiredrakeTorchOperator().apply

def to_ml_backend(self, x, gather=False, batched=True, **kwargs):
r"""Convert a Firedrake object `x` into a PyTorch tensor.

:arg x: Firedrake object (Function, Vector, Constant)
:kwarg gather: if True, gather data from all processes
:kwarg batched: if True, add a batch dimension to the tensor
:kwarg kwargs: additional arguments to be passed to torch.Tensor constructor
- device: device on which the tensor is allocated (default: "cpu")
- dtype: the desired data type of returned tensor (default: type of x.dat.data)
- requires_grad: if the tensor should be annotated (default: False)
"""
if isinstance(x, (Function, Vector)):
if gather:
# Gather data from all processes
x_P = self.backend.tensor(x.vector().gather(), **kwargs)
else:
# Use local data
x_P = self.backend.tensor(x.vector().get_local(), **kwargs)
if batched:
# Default behaviour: add batch dimension after converting to PyTorch
return x_P[None, :]
return x_P
elif isinstance(x, Constant):
return self.backend.tensor(x.values(), **kwargs)
elif isinstance(x, (float, int)):
if isinstance(x, float):
# Set double-precision
kwargs['dtype'] = self.backend.double
return self.backend.tensor(x, **kwargs)
else:
raise ValueError("Cannot convert %s to a torch tensor" % str(type(x)))

def from_ml_backend(self, x, V=None):
r"""Convert a PyTorch tensor `x` into a Firedrake object.

:arg x: PyTorch tensor (torch.Tensor)
:kwarg V: function space of the corresponding Function or None when `x` is to be mapped to a Constant
"""
if x.device.type != "cpu":
raise NotImplementedError("Firedrake does not support GPU tensors")

if V is None:
val = x.detach().numpy()
if val.shape == (1,):
val = val[0]
return Constant(val)
else:
x = x.detach().numpy()
x_F = Function(V, dtype=x.dtype)
x_F.vector().set_local(x)
return x_F


def get_backend(backend_name='pytorch'):
if backend_name == 'pytorch':
return PytorchBackend()
else:
raise NotImplementedError("The backend: %s is not supported." % backend_name)
86 changes: 86 additions & 0 deletions firedrake/pytorch_coupling/pytorch_custom_operator.py
@@ -0,0 +1,86 @@
import collections
from functools import partial

from firedrake.pytorch_coupling import get_backend
from firedrake.function import Function

from pyadjoint.reduced_functional import ReducedFunctional


backend = get_backend("pytorch")

if backend:
# PyTorch is installed
BackendFunction = backend.backend.autograd.Function
else:
class BackendFunction(object):
"""Dummy class that exceptions on instantiation."""
def __init__(self):
raise ImportError("PyTorch is not installed and is required to use the FiredrakeTorchOperator.")


class FiredrakeTorchOperator(BackendFunction):
"""
PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F.
`FiredrakeTorchOperator` is a wrapper around `torch.autograd.Function` that executes forward and backward
passes by directly calling the reduced functional F.

Inputs:
metadata: dictionary used to stash Firedrake objects.
*ω: PyTorch tensors representing the inputs to the Firedrake operator F

Outputs:
y: PyTorch tensor representing the output of the Firedrake operator F
"""

def __init__(self):
super(FiredrakeTorchOperator, self).__init__()

# This method is wrapped by something cancelling annotation (probably 'with torch.no_grad()')
@staticmethod
def forward(ctx, metadata, *ω):
"""Forward pass of the PyTorch custom operator."""
F = metadata['F']
V = metadata['V_controls']
# Convert PyTorch input (i.e. controls) to Firedrake
ω_F = [backend.from_ml_backend(ωi, Vi) for ωi, Vi in zip(ω, V)]
# Forward operator: delegated to pyadjoint.ReducedFunctional which recomputes the blocks on the tape
y_F = F(ω_F)
# Stash metadata to the PyTorch context
ctx.metadata.update(metadata)
# Convert Firedrake output to PyTorch
y = backend.to_ml_backend(y_F)
return y.detach()

@staticmethod
def backward(ctx, grad_output):
"""Backward pass of the PyTorch custom operator."""
F = ctx.metadata['F']
V = ctx.metadata['V_output']
# Convert PyTorch gradient to Firedrake
adj_input = backend.from_ml_backend(grad_output, V)
if isinstance(adj_input, Function):
adj_input = adj_input.vector()

# Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
Δω = F.derivative(adj_input=adj_input)

# Tuplify adjoint output
Δω = (Δω,) if not isinstance(Δω, collections.abc.Sequence) else Δω

# None is for metadata arg in `forward`
return None, *[backend.to_ml_backend(Δωi) for Δωi in Δω]


def torch_operator(F):
"""Operator that converts a pyadjoint.ReducedFunctional into a firedrake.FiredrakeTorchOperator
whose inputs and outputs are PyTorch tensors.
"""
if not isinstance(F, ReducedFunctional):
raise ValueError("F must be a ReducedFunctional")

V_output = backend.get_function_space(F.functional)
V_controls = [c.control.function_space() for c in F.controls]
metadata = {'F': F, 'V_controls': V_controls, 'V_output': V_output}
φ = partial(backend.custom_operator, metadata)
return φ
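To make the intended workflow concrete, here is a minimal sketch of coupling a taped Firedrake model into a PyTorch computation. The reduced functional F and its control Function u are assumed to have been built with pyadjoint beforehand, and F is assumed to return a scalar functional; this is illustrative, not code from this PR:

from firedrake.pytorch_coupling import get_backend, torch_operator

backend = get_backend("pytorch")

# F: a pyadjoint.ReducedFunctional wrapping a taped Firedrake forward model (assumed)
# u: the Firedrake Function used as its control (assumed)
G = torch_operator(F)             # callable acting on PyTorch tensors

x = backend.to_ml_backend(u)      # dofs of u as a (batched) torch tensor
x.requires_grad_()                # track gradients through the Firedrake operator

y = G(x)                          # forward pass: recomputes the tape via F
y.backward()                      # backward pass: dispatches to F.derivative()
dJdx = x.grad                     # gradient w.r.t. the control, as a torch tensor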
11 changes: 11 additions & 0 deletions tests/conftest.py
@@ -44,6 +44,9 @@ def pytest_configure(config):
config.addinivalue_line(
"markers",
"skipcomplexnoslate: mark as skipped in complex mode due to lack of Slate")
config.addinivalue_line(
"markers",
"skiptorch: mark as skipped if PyTorch is not installed")


@pytest.fixture(autouse=True)
@@ -79,6 +82,10 @@ def pytest_runtest_call(item):

def pytest_collection_modifyitems(session, config, items):
from firedrake.utils import SLATE_SUPPORTS_COMPLEX
from firedrake.pytorch_coupling import get_backend

backend = get_backend("pytorch")

for item in items:
if complex_mode:
if item.get_closest_marker("skipcomplex") is not None:
@@ -89,6 +96,10 @@ def pytest_collection_modifyitems(session, config, items):
if item.get_closest_marker("skipreal") is not None:
item.add_marker(pytest.mark.skip(reason="Test makes no sense unless in complex mode"))

if not backend:
if item.get_closest_marker("skiptorch") is not None:
item.add_marker(pytest.mark.skip(reason="Test makes no sense if PyTorch is not installed"))


@pytest.fixture(scope="module", autouse=True)
def check_empty_tape(request):
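As an illustration, a test can opt in to the new skiptorch marker as follows (a hypothetical test, not part of this PR):

import pytest

@pytest.mark.skiptorch            # skipped automatically when PyTorch is not installed
def test_pytorch_backend_available():
    from firedrake.pytorch_coupling import get_backend
    assert get_backend("pytorch")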