From 6535f0a1bd9ee6f22f2d5c625fee2da54f19e3fc Mon Sep 17 00:00:00 2001
From: nbouziani
Date: Thu, 1 Dec 2022 00:15:24 +0000
Subject: [PATCH 01/48] Add backend skeleton

---
 .../neural_networks/backends.py | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)
 create mode 100644 firedrake/external_operators/neural_networks/backends.py

diff --git a/firedrake/external_operators/neural_networks/backends.py b/firedrake/external_operators/neural_networks/backends.py
new file mode 100644
index 0000000000..542a8f5922
--- /dev/null
+++ b/firedrake/external_operators/neural_networks/backends.py
@@ -0,0 +1,56 @@
+from firedrake.function import Function
+from firedrake.cofunction import Cofunction
+
+from pytorch_custom_operator import CustomOperator
+
+import firedrake.utils as utils
+
+
+class AbstractMLBackend(object):
+
+    def backend(self):
+        raise NotImplementedError
+
+    def to_ml_backend(self, x):
+        raise NotImplementedError
+
+    def from_ml_backend(self, x, V, cofunction=None):
+        raise NotImplementedError
+
+
+class PytorchBackend(AbstractMLBackend):
+
+    @utils.cached_property
+    def backend(self):
+        try:
+            import torch
+        except ImportError:
+            raise ImportError("Error when trying to import PyTorch")
+        return torch
+
+    @utils.cached_property
+    def custom_operator(self):
+        return CustomOperator().apply
+
+    def to_ml_backend(self, x):
+        return self.backend.tensor(x.dat.data, requires_grad=True)
+
+    def from_ml_backend(x, V, cofunction=False):
+        if cofunction:
+            u = Cofunction(V.dual())
+        else:
+            u = Function(V)
+        u.vector()[:] = x.detach().numpy()
+        return u
+
+
+def get_backend(backend_name):
+    if backend_name == 'pytorch':
+        return PytorchBackend()
+    else:
+        error_msg = """ The backend: "%s" is not implemented!
+        -> You can do so by subclassing the `NeuralNet` class and making your own neural network class
+        for that backend!
+        See, for example, the `firedrake.external_operators.PytorchOperator` class associated with the PyTorch backend.
+        """ % backend_name
+        raise NotImplementedError(error_msg)

From 754f5d635df6001c7623bbee8ac21d96c3121ae0 Mon Sep 17 00:00:00 2001
From: nbouziani
Date: Wed, 1 Feb 2023 17:33:52 +0000
Subject: [PATCH 02/48] Update backend

---
 .../external_operators/neural_networks/backends.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/firedrake/external_operators/neural_networks/backends.py b/firedrake/external_operators/neural_networks/backends.py
index 542a8f5922..d118c667db 100644
--- a/firedrake/external_operators/neural_networks/backends.py
+++ b/firedrake/external_operators/neural_networks/backends.py
@@ -1,5 +1,4 @@
 from firedrake.function import Function
-from firedrake.cofunction import Cofunction
 
 from pytorch_custom_operator import CustomOperator
 
 import firedrake.utils as utils
@@ -14,7 +13,7 @@ def backend(self):
     def to_ml_backend(self, x):
         raise NotImplementedError
 
-    def from_ml_backend(self, x, V, cofunction=None):
+    def from_ml_backend(self, x, V):
         raise NotImplementedError
 
 
@@ -35,11 +34,8 @@ def custom_operator(self):
     def to_ml_backend(self, x):
         return self.backend.tensor(x.dat.data, requires_grad=True)
 
-    def from_ml_backend(x, V, cofunction=False):
-        if cofunction:
-            u = Cofunction(V.dual())
-        else:
-            u = Function(V)
+    def from_ml_backend(x, V):
+        u = Function(V)
         u.vector()[:] = x.detach().numpy()
         return u

From f1e37fd85467028f65714263a3ad4ca03897d9d4 Mon Sep 17 00:00:00 2001
From: nbouziani
Date: Wed, 1 Feb 2023 17:35:03 +0000
Subject: [PATCH 03/48] Add PyTorch custom operator

---
 .../pytorch_custom_operator.py | 80 +++++++++++++++++++
 1 file changed, 80 insertions(+)
 create mode 100644 firedrake/external_operators/neural_networks/pytorch_custom_operator.py

diff --git a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py b/firedrake/external_operators/neural_networks/pytorch_custom_operator.py
new file mode 100644
index 0000000000..abbcd1d6c1
--- /dev/null
+++ b/firedrake/external_operators/neural_networks/pytorch_custom_operator.py
@@ -0,0 +1,80 @@
+import collections
+
+import torch
+import torch.autograd as torch_ad
+
+from firedrake.external_operators.neural_networks import get_backend
+from firedrake.function import Function
+
+
+backend = get_backend('pytorch')
+
+
+class FiredrakeTorchOperator(torch_ad.Function):
+    """
+    We can implement our own custom autograd Functions by subclassing
+    torch.autograd.Function and implementing the forward and backward passes
+    which operate on Tensors.
+    """
+
+    # This method is wrapped by something cancelling annotation (probably 'with torch.no_grad()')
+    @staticmethod
+    def forward(ctx, metadata, *ω):
+        """
+        In the forward pass we receive a Tensor containing the input and return
+        a Tensor containing the output. ctx is a context object that can be used
+        to stash information for backward computation. You can cache arbitrary
+        objects for use in the backward pass using the ctx.save_for_backward method.
+        """
+        F = metadata['F']
+        V = metadata['V_controls']
+        # w can be list/tuple of model parameters or firedrake type.
+        # Converter checks first firedrake type if not check if list/tuple check
+        # all elements are parameters type and then return Constant subclass (PyTorchParams)
+        # Convert PyTorch input (i.e. controls) to Firedrake
+        ω_F = [backend.from_ml_backend(ωi, Vi) for ωi, Vi in zip(ω, V)]
+
+        # Should we turn annotation pyadjoint also if not turned on ?
+
+        # Forward operator: `ReducedFunctional` recompute blocks on the tape
+        y_F = F(*ω_F)
+        # Attach metadata to the PyTorch context
+        ctx.metadata.update(metadata)
+        # Convert Firedrake output to PyTorch
+        y = backend.to_ml_backend(y_F)
+        return y.detach()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        In the backward pass we receive a Tensor containing the gradient of the loss
+        with respect to the output, and we need to compute the gradient of the loss
+        with respect to the input.
+        """
+        F = ctx.metadata['F']
+        V = ctx.metadata['V_output']
+
+        adj_input = backend.from_ml_backend(grad_output, V)
+        if isinstance(adj_input, Function):
+            adj_input = adj_input.vector()
+
+        # Compute adjoint model of the Firedrake operator `F` on `adj_input`
+        Δω = F.derivative(adj_input=adj_input)
+
+        # Tuplify
+        Δω = (Δω,) if not isinstance(Δω, collections.abc.Sequence) else Δω
+
+        # None is for metadata arg in `forward`
+        return None, *[backend.to_ml_backend(Δωi) for Δωi in Δω]
+
+
+def to_pytorch(*args, **kwargs):
+    # Avoid circular import
+    from firedrake.external_operators.neural_networks.backends import PytorchBackend
+    return PytorchBackend().to_ml_backend(*args, **kwargs)
+
+
+def from_pytorch(*args, **kwargs):
+    # Avoid circular import
+    from firedrake.external_operators.neural_networks.backends import PytorchBackend
+    return PytorchBackend().from_ml_backend(*args, **kwargs)

From ce9ac31e519844806f63ec9417ec5365e360b740 Mon Sep 17 00:00:00 2001
From: nbouziani
Date: Wed, 1 Feb 2023 17:36:34 +0000
Subject: [PATCH 04/48] Add HybridOperator

---
 .../neural_networks/ml_backend_coupling.py | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 firedrake/external_operators/neural_networks/ml_backend_coupling.py

diff --git a/firedrake/external_operators/neural_networks/ml_backend_coupling.py b/firedrake/external_operators/neural_networks/ml_backend_coupling.py
new file mode 100644
index 0000000000..5ca4df4fe6
--- /dev/null
+++ b/firedrake/external_operators/neural_networks/ml_backend_coupling.py
@@ -0,0 +1,31 @@
+from functools import partial
+
+from firedrake.external_operators.neural_networks.backends import get_backend
+
+
+class HybridOperator(object):
+    """
+    F: Firedrake operator
+    """
+    def __init__(self, F, backend='pytorch'):
+        # Add sugar syntax if F not a callable (e.g. Form or ExternalOperator)
+        self.F = F
+        self.backend = get_backend(backend)
+        self.custom_operator = self.backend.custom_operator
+
+    def __call__(self, *ω):
+        r"""
+        ω can be model parameters, firedrake object or list of firedrake object
+        Example: Let y = f(x; θ) with f a neural network of inputs x and parameters θ
+
+        1) ω = θ (Inverse problem using ExternalOperator)
+           ...
+        2) ω = y (physics-driven neural networks)
+           ...
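+
+        A minimal usage sketch (for illustration only; it assumes `Jhat` is a
+        `pyadjoint.ReducedFunctional` expressing the Firedrake operations and
+        `x_P` is a `torch.Tensor` input):
+
+            G = HybridOperator(Jhat)
+            y_P = G(x_P)  # `y_P` is a `torch.Tensor` through which PyTorch can backpropagate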
+ """ + V_controls = [c.control.function_space() for c in self.F.controls] + F_output = self.F.functional + V_output = self.backend.get_function_space(F_output) + metadata = {'F': self.F, 'V_controls': V_controls, 'V_output': V_output} + φ = partial(self.custom_operator, metadata) + return φ(*ω) From a72a06ee73c22a8f1aab4e5c679b593745f33659 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 1 Feb 2023 17:56:29 +0000 Subject: [PATCH 05/48] Add test --- .../neural_networks/test_pytorch_coupling.py | 232 ++++++++++++++++++ 1 file changed, 232 insertions(+) create mode 100644 tests/external_operators/neural_networks/test_pytorch_coupling.py diff --git a/tests/external_operators/neural_networks/test_pytorch_coupling.py b/tests/external_operators/neural_networks/test_pytorch_coupling.py new file mode 100644 index 0000000000..bedb4ecea3 --- /dev/null +++ b/tests/external_operators/neural_networks/test_pytorch_coupling.py @@ -0,0 +1,232 @@ +import pytest + +import torch +import torch.nn.functional as torch_func +from torch.nn import Module, Flatten, Linear + +from firedrake import * +from firedrake_adjoint import * +from firedrake.external_operators.neural_networks.backends import get_backend +from pyadjoint.tape import get_working_tape, pause_annotation + + +@pytest.fixture(autouse=True) +def handle_taping(): + yield + tape = get_working_tape() + tape.clear_tape() + + +@pytest.fixture(autouse=True, scope="module") +def handle_annotation(): + from firedrake_adjoint import annotate_tape, continue_annotation + if not annotate_tape(): + continue_annotation() + yield + # Since importing firedrake_adjoint modifies a global variable, we need to + # pause annotations at the end of the module + annotate = annotate_tape() + if annotate: + pause_annotation() + + +@pytest.fixture(scope='module') +def mesh(): + return UnitSquareMesh(10, 10) + + +@pytest.fixture(scope='module') +def V(mesh): + return FunctionSpace(mesh, "CG", 1) + + +@pytest.fixture +def f_exact(V, mesh): + x, y = SpatialCoordinate(mesh) + return Function(V).interpolate(sin(pi * x) * sin(pi * y)) + + +class EncoderDecoder(Module): + """Build a simple toy model""" + + def __init__(self, n): + super(EncoderDecoder, self).__init__() + self.n1 = n + self.n2 = int(n/2) + self.flatten = Flatten() + self.encoder_1 = Linear(self.n1, self.n2) + self.decoder_1 = Linear(self.n2, self.n1) + + def encode(self, x): + return self.encoder_1(x) + + def decode(self, x): + return self.decoder_1(x) + + def forward(self, x): + # [batch_size, n] + x = self.flatten(x) + # [batch_size, n2] + encoded = self.encode(x) + hidden = torch_func.relu(encoded) + # [batch_size, n] + decoded = self.decode(hidden) + return torch_func.relu(decoded) + + +# Set of Firedrake operations that will be composed with PyTorch operations +def poisson_residual(u, f, V): + """Assemble the residual of a Poisson problem""" + v = TestFunction(V) + F = (inner(grad(u), grad(v)) + inner(u, v) - inner(f, v)) * dx + return assemble(F) + + +# Set of Firedrake operations that will be composed with PyTorch operations +def solve_poisson(f, V): + """Solve Poisson problem""" + u = Function(V) + v = TestFunction(V) + F = (inner(grad(u), grad(v)) + inner(u, v) - inner(f, v)) * dx + bcs = [DirichletBC(V, Constant(1.0), "on_boundary")] + # Solve PDE + solve(F == 0, u, bcs=bcs) + # Assemble Firedrake loss + return assemble(u ** 2 * dx) + + +@pytest.fixture(params=['poisson_residual', 'solve_poisson']) +def firedrake_operator(request, f_exact, V): + # Return firedrake operator and the corresponding non-control 
arguments + if request.param == 'poisson_residual': + return poisson_residual, (f_exact, V) + elif request.param == 'solve_poisson': + return solve_poisson, (V,) + + +@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done +def test_pytorch_loss_backward(V, f_exact): + """Add doc """ + + # Instantiate model + model = EncoderDecoder(V.dim()) + + # Set double precision + model.double() + + # Check that gradients are initially set to None + assert all([θi.grad is None for θi in model.parameters()]) + + # Get machine learning backend (default: PyTorch) + pytorch_backend = get_backend() + + # Model input + f = Function(V) + + # Convert f to torch.Tensor + f_P = pytorch_backend.to_ml_backend(f) + + # Forward pass + y_P = model(f_P) + + # Set control + u_F = Function(V) + c = Control(u_F) + + # Set reduced functional which expresses the Firedrake operations in terms of the control + Jhat = ReducedFunctional(poisson_residual(u_F, f_exact, V), c) + + # Construct the HybridOperator that takes a callable representing the Firedrake operations + G = HybridOperator(Jhat) + + # Compute Poisson residual in Firedrake using HybridOperator: `residual_P` is a torch.Tensor + residual_P = G(y_P) + + # Compute PyTorch loss + loss = (residual_P ** 2).sum() + + # -- Check backpropagation API -- # + loss.backward() + + # Check that gradients were propagated to model parameters + # This test doesn't check the correctness of these gradients + # -> This is checked in `test_taylor_hybrid_operator` + assert all([θi.grad is not None for θi in model.parameters()]) + + # -- Check forward operator -- # + y_F = pytorch_backend.from_ml_backend(y_P, V) + residual_F = poisson_residual(y_F, f_exact, V) + residual_P_exact = pytorch_backend.to_ml_backend(residual_F) + + assert (residual_P - residual_P_exact).detach().norm() < 1e-10 + + +@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done +def test_firedrake_loss_backward(V, f_exact): + """Add doc """ + + # Instantiate model + model = EncoderDecoder(V.dim()) + + # Set double precision + model.double() + + # Check that gradients are initially set to None + assert all([θi.grad is None for θi in model.parameters()]) + + # Get machine learning backend (default: PyTorch) + pytorch_backend = get_backend() + + # Model input + λ = Function(V) + + # Convert f to torch.Tensor + λ_P = pytorch_backend.to_ml_backend(λ) + + # Forward pass + f_P = model(λ_P) + + # Set control + f = Function(V) + c = Control(f) + + # Set reduced functional which expresses the Firedrake operations in terms of the control + Jhat = ReducedFunctional(solve_poisson(f, V), c) + + # Construct the HybridOperator that takes a callable representing the Firedrake operations + G = HybridOperator(Jhat) + + # Solve Poisson problem and compute the loss defined as the L2-norm of the solution + # -> `loss_P` is a torch.Tensor + loss_P = G(f_P) + + # -- Check backpropagation API -- # + loss_P.backward() + + # Check that gradients were propagated to model parameters + # This test doesn't check the correctness of these gradients + # -> This is checked in `test_taylor_hybrid_operator` + assert all([θi.grad is not None for θi in model.parameters()]) + + # -- Check forward operator -- # + f_F = pytorch_backend.from_ml_backend(f_P, V) + loss_F = solve_poisson(f_F, V) + loss_P_exact = pytorch_backend.to_ml_backend(loss_F) + + assert (loss_P - loss_P_exact).detach().norm() < 1e-10 + + +@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done +def 
test_taylor_hybrid_operator(firedrake_operator, V): + # Control value + ω = Function(V) + # Get Firedrake operator and other operator arguments + fd_op, args = firedrake_operator + # Set reduced functional + Jhat = ReducedFunctional(fd_op(ω, *args), Control(ω)) + # Define the hybrid operator + G = HybridOperator(Jhat) + # `gradcheck` is likey to fail if the inputs are not double precision (cf. https://pytorch.org/docs/stable/generated/torch.autograd.gradcheck.html) + x_P = torch.rand(V.dim(), dtype=torch.double, requires_grad=True) + # Taylor test (`eps` is the perturbation) + torch.autograd.gradcheck(G, x_P, eps=1e-6) From cd5f11d127b93fcd0951175d3e91a4dc377409db Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 1 Feb 2023 18:10:09 +0000 Subject: [PATCH 06/48] Update backend --- .../neural_networks/backends.py | 70 ++++++++++++++----- 1 file changed, 54 insertions(+), 16 deletions(-) diff --git a/firedrake/external_operators/neural_networks/backends.py b/firedrake/external_operators/neural_networks/backends.py index d118c667db..0c1bae0bf1 100644 --- a/firedrake/external_operators/neural_networks/backends.py +++ b/firedrake/external_operators/neural_networks/backends.py @@ -1,6 +1,6 @@ from firedrake.function import Function - -from pytorch_custom_operator import CustomOperator +from firedrake.vector import Vector +from firedrake.constant import Constant import firedrake.utils as utils @@ -11,11 +11,26 @@ def backend(self): raise NotImplementedError def to_ml_backend(self, x): + """Convert from Firedrake to ML backend + x: Firedrake object + """ raise NotImplementedError def from_ml_backend(self, x, V): + """Convert from ML backend to Firedrake + x: ML backend object + """ raise NotImplementedError + def get_function_space(self, x): + """Get function space out of x""" + if isinstance(x, Function): + return x.function_space() + elif isinstance(x, float): + return None + else: + raise ValueError('Cannot infer the function space of %s' % x) + class PytorchBackend(AbstractMLBackend): @@ -29,24 +44,47 @@ def backend(self): @utils.cached_property def custom_operator(self): - return CustomOperator().apply + from firedrake.external_operators.neural_networks.pytorch_custom_operator import FiredrakeTorchOperator + return FiredrakeTorchOperator().apply - def to_ml_backend(self, x): - return self.backend.tensor(x.dat.data, requires_grad=True) + def to_ml_backend(self, x, unsqueeze=True, unsqueeze_dim=0): + # Work out what's the right thing to do here ? + requires_grad = True + if isinstance(x, (Function, Vector)): + # Should we use `.dat.data` instead of `.dat.data_ro` to increase the state counter ? 
+ x_P = self.backend.tensor(x.dat.data_ro, requires_grad=requires_grad) + # Default behaviour: unsqueeze after converting to PyTorch + # Shape: [1, x.dat.shape] + if unsqueeze: + x_P = x_P.unsqueeze(unsqueeze_dim) + return x_P + # Add case subclass constant representing theta + # elif isinstance(x, ...): + elif isinstance(x, Constant): + return self.backend.tensor(x.values(), requires_grad=requires_grad) + elif isinstance(x, (float, int)): + # Covers pyadjoint AdjFloat as well + return self.backend.tensor(x, requires_grad=requires_grad) + else: + raise ValueError("Cannot convert %s to the ML backend environment" % str(type(x))) - def from_ml_backend(x, V): - u = Function(V) - u.vector()[:] = x.detach().numpy() - return u + def from_ml_backend(self, x, V=None): + if V is None: + val = x.detach().numpy() + return Constant(val) + else: + u = Function(V) + # Default behaviour: squeeze before converting to Firedrake + # This is motivated by the fact that assigning to numpy array to `u` will automatically squeeze + # the batch dimension behind the scenes + # Shape: [x.shape] + x = x.squeeze(0) + u.vector()[:] = x.detach().numpy() + return u -def get_backend(backend_name): +def get_backend(backend_name='pytorch'): if backend_name == 'pytorch': return PytorchBackend() else: - error_msg = """ The backend: "%s" is not implemented! - -> You can do so by sublcassing the `NeuralNet` class and make your own neural network class - for that backend! - See, for example, the `firedrake.external_operators.PytorchOperator` class associated with the PyTorch backend. - """ % backend_name - raise NotImplementedError(error_msg) + raise NotImplementedError("The backend: %s is not supported." % backend_name) From f388b88be4cb7777a5827054622d0afd721827a3 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Sun, 5 Feb 2023 03:32:13 +0000 Subject: [PATCH 07/48] Fix reduced functional call in FiredrakeTorchOperator --- .../neural_networks/pytorch_custom_operator.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py b/firedrake/external_operators/neural_networks/pytorch_custom_operator.py index abbcd1d6c1..54e22e4eaf 100644 --- a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py +++ b/firedrake/external_operators/neural_networks/pytorch_custom_operator.py @@ -1,6 +1,5 @@ import collections -import torch import torch.autograd as torch_ad from firedrake.external_operators.neural_networks import get_backend @@ -37,8 +36,8 @@ def forward(ctx, metadata, *ω): # Should we turn annotation pyadjoint also if not turned on ? 
# Forward operator: `ReducedFunctional` recompute blocks on the tape - y_F = F(*ω_F) - # Attach metadata to the PyTorch context + y_F = F(ω_F) + # Attach metadata to the PyTorch contextx ctx.metadata.update(metadata) # Convert Firedrake output to PyTorch y = backend.to_ml_backend(y_F) From 422b1d5a5d0fbda7b0322a77cf8f29bcbd6f847e Mon Sep 17 00:00:00 2001 From: nbouziani Date: Sun, 5 Feb 2023 03:32:47 +0000 Subject: [PATCH 08/48] Add torch_op helper function --- firedrake/__init__.py | 1 + .../external_operators/neural_networks/ml_backend_coupling.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 3fd3eb910c..8e25e01e1c 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -103,6 +103,7 @@ from firedrake.ensemble import * from firedrake.randomfunctiongen import * from firedrake.progress_bar import ProgressBar # noqa: F401 +from firedrake.external_operators import * from firedrake.logging import * # Set default log level diff --git a/firedrake/external_operators/neural_networks/ml_backend_coupling.py b/firedrake/external_operators/neural_networks/ml_backend_coupling.py index 5ca4df4fe6..c1c5463d13 100644 --- a/firedrake/external_operators/neural_networks/ml_backend_coupling.py +++ b/firedrake/external_operators/neural_networks/ml_backend_coupling.py @@ -29,3 +29,7 @@ def __call__(self, *ω): metadata = {'F': self.F, 'V_controls': V_controls, 'V_output': V_output} φ = partial(self.custom_operator, metadata) return φ(*ω) + + +def torch_op(*args, **kwargs): + return HybridOperator(*args, **kwargs) From 0288d0aa58bdfd17b5d33339d930390145532955 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Tue, 7 Feb 2023 13:15:15 +0000 Subject: [PATCH 09/48] Rename torch_op -> torch_operator --- firedrake/external_operators/__init__.py | 1 + firedrake/external_operators/neural_networks/__init__.py | 2 ++ .../external_operators/neural_networks/ml_backend_coupling.py | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 firedrake/external_operators/__init__.py create mode 100644 firedrake/external_operators/neural_networks/__init__.py diff --git a/firedrake/external_operators/__init__.py b/firedrake/external_operators/__init__.py new file mode 100644 index 0000000000..d37df34173 --- /dev/null +++ b/firedrake/external_operators/__init__.py @@ -0,0 +1 @@ +from firedrake.external_operators.neural_networks import * \ No newline at end of file diff --git a/firedrake/external_operators/neural_networks/__init__.py b/firedrake/external_operators/neural_networks/__init__.py new file mode 100644 index 0000000000..6be941f2dd --- /dev/null +++ b/firedrake/external_operators/neural_networks/__init__.py @@ -0,0 +1,2 @@ +from .backends import get_backend +from .ml_backend_coupling import HybridOperator, torch_operator \ No newline at end of file diff --git a/firedrake/external_operators/neural_networks/ml_backend_coupling.py b/firedrake/external_operators/neural_networks/ml_backend_coupling.py index c1c5463d13..bc5fff7b28 100644 --- a/firedrake/external_operators/neural_networks/ml_backend_coupling.py +++ b/firedrake/external_operators/neural_networks/ml_backend_coupling.py @@ -31,5 +31,5 @@ def __call__(self, *ω): return φ(*ω) -def torch_op(*args, **kwargs): +def torch_operator(*args, **kwargs): return HybridOperator(*args, **kwargs) From e01fec55ca93bc76caccc77f91ab78f17856d4c4 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Fri, 10 Feb 2023 22:32:38 +0000 Subject: [PATCH 10/48] Remove HybridOperator in favour of torch_operator --- 
.../neural_networks/__init__.py | 2 +- .../neural_networks/ml_backend_coupling.py | 35 ------------------- .../pytorch_custom_operator.py | 26 ++++++++------ .../neural_networks/test_pytorch_coupling.py | 20 +++++------ 4 files changed, 26 insertions(+), 57 deletions(-) delete mode 100644 firedrake/external_operators/neural_networks/ml_backend_coupling.py diff --git a/firedrake/external_operators/neural_networks/__init__.py b/firedrake/external_operators/neural_networks/__init__.py index 6be941f2dd..dd31093899 100644 --- a/firedrake/external_operators/neural_networks/__init__.py +++ b/firedrake/external_operators/neural_networks/__init__.py @@ -1,2 +1,2 @@ from .backends import get_backend -from .ml_backend_coupling import HybridOperator, torch_operator \ No newline at end of file +from .pytorch_custom_operator import torch_operator \ No newline at end of file diff --git a/firedrake/external_operators/neural_networks/ml_backend_coupling.py b/firedrake/external_operators/neural_networks/ml_backend_coupling.py deleted file mode 100644 index bc5fff7b28..0000000000 --- a/firedrake/external_operators/neural_networks/ml_backend_coupling.py +++ /dev/null @@ -1,35 +0,0 @@ -from functools import partial - -from firedrake.external_operators.neural_networks.backends import get_backend - - -class HybridOperator(object): - """ - F: Firedrake operator - """ - def __init__(self, F, backend='pytorch'): - # Add sugar syntax if F not a callable (e.g. Form or ExternalOperator) - self.F = F - self.backend = get_backend(backend) - self.custom_operator = self.backend.custom_operator - - def __call__(self, *ω): - r""" - ω can be model parameters, firedrake object or list of firedrake object - Example: Let y = f(x; θ) with f a neural network of inputs x and parameters θ - - 1) ω = θ (Inverse problem using ExternalOperator) - ... - 2) ω = y (physics-driven neural networks) - ... 
- """ - V_controls = [c.control.function_space() for c in self.F.controls] - F_output = self.F.functional - V_output = self.backend.get_function_space(F_output) - metadata = {'F': self.F, 'V_controls': V_controls, 'V_output': V_output} - φ = partial(self.custom_operator, metadata) - return φ(*ω) - - -def torch_operator(*args, **kwargs): - return HybridOperator(*args, **kwargs) diff --git a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py b/firedrake/external_operators/neural_networks/pytorch_custom_operator.py index 54e22e4eaf..9b3d3ae5d8 100644 --- a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py +++ b/firedrake/external_operators/neural_networks/pytorch_custom_operator.py @@ -1,10 +1,12 @@ import collections - import torch.autograd as torch_ad +from functools import partial from firedrake.external_operators.neural_networks import get_backend from firedrake.function import Function +from pyadjoint.reduced_functional import ReducedFunctional + backend = get_backend('pytorch') @@ -67,13 +69,15 @@ def backward(ctx, grad_output): return None, *[backend.to_ml_backend(Δωi) for Δωi in Δω] -def to_pytorch(*args, **kwargs): - # Avoid circular import - from firedrake.external_operators.neural_networks.backends import PytorchBackend - return PytorchBackend().to_ml_backend(*args, **kwargs) - - -def from_pytorch(*args, **kwargs): - # Avoid circular import - from firedrake.external_operators.neural_networks.backends import PytorchBackend - return PytorchBackend().from_ml_backend(*args, **kwargs) +def torch_operator(F): + """Operator that converts a pyadjoint.ReducedFunctional into a firedrake.FiredrakeTorchOperator + whose inputs and outputs are PyTorch tensors. + """ + if not isinstance(F, ReducedFunctional): + raise ValueError("F must be a ReducedFunctional") + + V_output = backend.get_function_space(F.functional) + V_controls = [c.control.function_space() for c in F.controls] + metadata = {'F': F, 'V_controls': V_controls, 'V_output': V_output} + φ = partial(backend.custom_operator, metadata) + return φ diff --git a/tests/external_operators/neural_networks/test_pytorch_coupling.py b/tests/external_operators/neural_networks/test_pytorch_coupling.py index bedb4ecea3..26ca50be8a 100644 --- a/tests/external_operators/neural_networks/test_pytorch_coupling.py +++ b/tests/external_operators/neural_networks/test_pytorch_coupling.py @@ -136,10 +136,10 @@ def test_pytorch_loss_backward(V, f_exact): # Set reduced functional which expresses the Firedrake operations in terms of the control Jhat = ReducedFunctional(poisson_residual(u_F, f_exact, V), c) - # Construct the HybridOperator that takes a callable representing the Firedrake operations - G = HybridOperator(Jhat) + # Construct the torch operator that takes a callable representing the Firedrake operations + G = torch_operator(Jhat) - # Compute Poisson residual in Firedrake using HybridOperator: `residual_P` is a torch.Tensor + # Compute Poisson residual in Firedrake using the torch operator: `residual_P` is a torch.Tensor residual_P = G(y_P) # Compute PyTorch loss @@ -150,7 +150,7 @@ def test_pytorch_loss_backward(V, f_exact): # Check that gradients were propagated to model parameters # This test doesn't check the correctness of these gradients - # -> This is checked in `test_taylor_hybrid_operator` + # -> This is checked in `test_taylor_torch_operator` assert all([θi.grad is not None for θi in model.parameters()]) # -- Check forward operator -- # @@ -193,8 +193,8 @@ def test_firedrake_loss_backward(V, 
f_exact): # Set reduced functional which expresses the Firedrake operations in terms of the control Jhat = ReducedFunctional(solve_poisson(f, V), c) - # Construct the HybridOperator that takes a callable representing the Firedrake operations - G = HybridOperator(Jhat) + # Construct the torch operator that takes a callable representing the Firedrake operations + G = torch_operator(Jhat) # Solve Poisson problem and compute the loss defined as the L2-norm of the solution # -> `loss_P` is a torch.Tensor @@ -205,7 +205,7 @@ def test_firedrake_loss_backward(V, f_exact): # Check that gradients were propagated to model parameters # This test doesn't check the correctness of these gradients - # -> This is checked in `test_taylor_hybrid_operator` + # -> This is checked in `test_taylor_torch_operator` assert all([θi.grad is not None for θi in model.parameters()]) # -- Check forward operator -- # @@ -217,15 +217,15 @@ def test_firedrake_loss_backward(V, f_exact): @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done -def test_taylor_hybrid_operator(firedrake_operator, V): +def test_taylor_torch_operator(firedrake_operator, V): # Control value ω = Function(V) # Get Firedrake operator and other operator arguments fd_op, args = firedrake_operator # Set reduced functional Jhat = ReducedFunctional(fd_op(ω, *args), Control(ω)) - # Define the hybrid operator - G = HybridOperator(Jhat) + # Define the torch operator + G = torch_operator(Jhat) # `gradcheck` is likey to fail if the inputs are not double precision (cf. https://pytorch.org/docs/stable/generated/torch.autograd.gradcheck.html) x_P = torch.rand(V.dim(), dtype=torch.double, requires_grad=True) # Taylor test (`eps` is the perturbation) From 22f654c18d0b1b7cf7cc1e84a36e9483c2edd351 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Sat, 11 Feb 2023 00:12:43 +0000 Subject: [PATCH 11/48] Update comments FiredrakeTorchOperator --- .../pytorch_custom_operator.py | 40 ++++++++----------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py b/firedrake/external_operators/neural_networks/pytorch_custom_operator.py index 9b3d3ae5d8..61dee4fffe 100644 --- a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py +++ b/firedrake/external_operators/neural_networks/pytorch_custom_operator.py @@ -13,33 +13,29 @@ class FiredrakeTorchOperator(torch_ad.Function): """ - We can implement our own custom autograd Functions by subclassing - torch.autograd.Function and implementing the forward and backward passes - which operate on Tensors. + PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F. + `FiredrakeTorchOperator` is a wrapper around `torch.autograd.Function` that executes forward and backward + passes by directly calling the reduced functional F. + + Inputs: + metadata: dictionary used to stash Firedrake objects. + *ω: PyTorch tensors representing the inputs to the Firedrake operator F + + Outputs: + y: PyTorch tensor representing the output of the Firedrake operator F """ # This method is wrapped by something cancelling annotation (probably 'with torch.no_grad()') @staticmethod def forward(ctx, metadata, *ω): - """ - In the forward pass we receive a Tensor containing the input and return - a Tensor containing the output. ctx is a context object that can be used - to stash information for backward computation. 
You can cache arbitrary - objects for use in the backward pass using the ctx.save_for_backward method. - """ + """Forward pass of the PyTorch custom operator.""" F = metadata['F'] V = metadata['V_controls'] - # w can be list/tuple of model parameters or firedrake type. - # Converter checks first firedrake type if not check if list/tuple check - # all elements are parameters type and then return Constant subclass (PyTorchParams) # Convert PyTorch input (i.e. controls) to Firedrake ω_F = [backend.from_ml_backend(ωi, Vi) for ωi, Vi in zip(ω, V)] - - # Should we turn annotation pyadjoint also if not turned on ? - - # Forward operator: `ReducedFunctional` recompute blocks on the tape + # Forward operator: Delegated to pyadjoint.ReducedFunctional which recomputes the blocks on the tape y_F = F(ω_F) - # Attach metadata to the PyTorch contextx + # Stash metadata to the PyTorch context ctx.metadata.update(metadata) # Convert Firedrake output to PyTorch y = backend.to_ml_backend(y_F) @@ -47,19 +43,15 @@ def forward(ctx, metadata, *ω): @staticmethod def backward(ctx, grad_output): - """ - In the backward pass we receive a Tensor containing the gradient of the loss - with respect to the output, and we need to compute the gradient of the loss - with respect to the input. - """ + """Backward pass of the PyTorch custom operator.""" F = ctx.metadata['F'] V = ctx.metadata['V_output'] - + # Convert PyTorch gradient to Firedrake adj_input = backend.from_ml_backend(grad_output, V) if isinstance(adj_input, Function): adj_input = adj_input.vector() - # Compute adjoint model of the Firedrake operator `F` on `adj_input` + # Compute adjoint model of `F`: Delegated to pyadjoint.ReducedFunctional Δω = F.derivative(adj_input=adj_input) # Tuplify From 6c2072bec69cf5755d29bb86593ccac1bdf1d098 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Sat, 11 Feb 2023 03:18:08 +0000 Subject: [PATCH 12/48] Update pytorch backend mappings --- .../neural_networks/backends.py | 54 ++++++++++++------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/firedrake/external_operators/neural_networks/backends.py b/firedrake/external_operators/neural_networks/backends.py index 0c1bae0bf1..c041dc27b8 100644 --- a/firedrake/external_operators/neural_networks/backends.py +++ b/firedrake/external_operators/neural_networks/backends.py @@ -26,10 +26,12 @@ def get_function_space(self, x): """Get function space out of x""" if isinstance(x, Function): return x.function_space() + elif isinstance(x, Vector): + return self.get_function_space(x.function) elif isinstance(x, float): return None else: - raise ValueError('Cannot infer the function space of %s' % x) + raise ValueError("Cannot infer the function space of %s" % x) class PytorchBackend(AbstractMLBackend): @@ -47,40 +49,52 @@ def custom_operator(self): from firedrake.external_operators.neural_networks.pytorch_custom_operator import FiredrakeTorchOperator return FiredrakeTorchOperator().apply - def to_ml_backend(self, x, unsqueeze=True, unsqueeze_dim=0): - # Work out what's the right thing to do here ? 
- requires_grad = True + def to_ml_backend(self, x, gather=False, batched=True, **kwargs): + """ Convert a Firedrake object `x` into a PyTorch tensor + + x: Firedrake object (Function, Vector, Constant) + gather: if True, gather data from all processes + batched: if True, add a batch dimension to the tensor + kwargs: additional arguments to be passed to torch.Tensor constructor + - device: device on which the tensor is allocated (default: "cpu") + - dtype: the desired data type of returned tensor (default: type of x.dat.data) + - requires_grad: if the tensor should be annotated (default: False) + """ if isinstance(x, (Function, Vector)): - # Should we use `.dat.data` instead of `.dat.data_ro` to increase the state counter ? - x_P = self.backend.tensor(x.dat.data_ro, requires_grad=requires_grad) - # Default behaviour: unsqueeze after converting to PyTorch - # Shape: [1, x.dat.shape] - if unsqueeze: - x_P = x_P.unsqueeze(unsqueeze_dim) + # State counter: get_local does a copy and increase the state counter while gather does not. + # We probably always want to increase the state counter and therefore should do something for the gather case + if gather: + # Gather data from all processes + x_P = self.backend.tensor(x.vector().gather(), **kwargs) + else: + # Use local data + x_P = self.backend.tensor(x.vector().get_local(), **kwargs) + if batched: + # Default behaviour: add batch dimension after converting to PyTorch + return x_P[None, :] return x_P - # Add case subclass constant representing theta - # elif isinstance(x, ...): elif isinstance(x, Constant): - return self.backend.tensor(x.values(), requires_grad=requires_grad) + return self.backend.tensor(x.values(), **kwargs) elif isinstance(x, (float, int)): - # Covers pyadjoint AdjFloat as well - return self.backend.tensor(x, requires_grad=requires_grad) + return self.backend.tensor(x, **kwargs) else: - raise ValueError("Cannot convert %s to the ML backend environment" % str(type(x))) + raise ValueError("Cannot convert %s to a torch tensor" % str(type(x))) - def from_ml_backend(self, x, V=None): + def from_ml_backend(self, x, V=None, gather=False): if V is None: val = x.detach().numpy() + if val.shape == (1,): + val = val[0] return Constant(val) else: - u = Function(V) + x_F = Function(V) # Default behaviour: squeeze before converting to Firedrake # This is motivated by the fact that assigning to numpy array to `u` will automatically squeeze # the batch dimension behind the scenes # Shape: [x.shape] x = x.squeeze(0) - u.vector()[:] = x.detach().numpy() - return u + x_F.vector().set_local(x.detach().numpy()) + return x_F def get_backend(backend_name='pytorch'): From 0fa8738f17c5b48d822a5f3f7901fcf8c4e18958 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Sat, 11 Feb 2023 03:32:17 +0000 Subject: [PATCH 13/48] Use torch tensor type --- firedrake/external_operators/neural_networks/backends.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/firedrake/external_operators/neural_networks/backends.py b/firedrake/external_operators/neural_networks/backends.py index c041dc27b8..b897fe6e08 100644 --- a/firedrake/external_operators/neural_networks/backends.py +++ b/firedrake/external_operators/neural_networks/backends.py @@ -81,19 +81,23 @@ def to_ml_backend(self, x, gather=False, batched=True, **kwargs): raise ValueError("Cannot convert %s to a torch tensor" % str(type(x))) def from_ml_backend(self, x, V=None, gather=False): + if x.device.type != "cpu": + raise NotImplementedError("Firedrake does not support GPU tensors") + if V is 
None: val = x.detach().numpy() if val.shape == (1,): val = val[0] return Constant(val) else: - x_F = Function(V) + x = x.detach().numpy() + x_F = Function(V, dtype=x.dtype) # Default behaviour: squeeze before converting to Firedrake # This is motivated by the fact that assigning to numpy array to `u` will automatically squeeze # the batch dimension behind the scenes # Shape: [x.shape] x = x.squeeze(0) - x_F.vector().set_local(x.detach().numpy()) + x_F.vector().set_local(x) return x_F From 0772a4197439c7bc2de8ece5873f29611a23848f Mon Sep 17 00:00:00 2001 From: nbouziani Date: Mon, 6 Mar 2023 13:55:54 +0000 Subject: [PATCH 14/48] Remove squeezing when mapping from pytorch --- firedrake/external_operators/neural_networks/backends.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/firedrake/external_operators/neural_networks/backends.py b/firedrake/external_operators/neural_networks/backends.py index b897fe6e08..eab0725186 100644 --- a/firedrake/external_operators/neural_networks/backends.py +++ b/firedrake/external_operators/neural_networks/backends.py @@ -92,11 +92,6 @@ def from_ml_backend(self, x, V=None, gather=False): else: x = x.detach().numpy() x_F = Function(V, dtype=x.dtype) - # Default behaviour: squeeze before converting to Firedrake - # This is motivated by the fact that assigning to numpy array to `u` will automatically squeeze - # the batch dimension behind the scenes - # Shape: [x.shape] - x = x.squeeze(0) x_F.vector().set_local(x) return x_F From d5eca4785a58708df1bc773f68c60a7165787d64 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Tue, 7 Mar 2023 16:31:18 +0000 Subject: [PATCH 15/48] Fix single-precision casting of AdjFloat --- firedrake/external_operators/neural_networks/backends.py | 4 +++- .../neural_networks/pytorch_custom_operator.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/firedrake/external_operators/neural_networks/backends.py b/firedrake/external_operators/neural_networks/backends.py index eab0725186..644daba04d 100644 --- a/firedrake/external_operators/neural_networks/backends.py +++ b/firedrake/external_operators/neural_networks/backends.py @@ -1,3 +1,5 @@ +import numpy as np + from firedrake.function import Function from firedrake.vector import Vector from firedrake.constant import Constant @@ -76,7 +78,7 @@ def to_ml_backend(self, x, gather=False, batched=True, **kwargs): elif isinstance(x, Constant): return self.backend.tensor(x.values(), **kwargs) elif isinstance(x, (float, int)): - return self.backend.tensor(x, **kwargs) + return self.backend.tensor(np.array(x), **kwargs) else: raise ValueError("Cannot convert %s to a torch tensor" % str(type(x))) diff --git a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py b/firedrake/external_operators/neural_networks/pytorch_custom_operator.py index 61dee4fffe..f2df6fe373 100644 --- a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py +++ b/firedrake/external_operators/neural_networks/pytorch_custom_operator.py @@ -33,7 +33,7 @@ def forward(ctx, metadata, *ω): V = metadata['V_controls'] # Convert PyTorch input (i.e. 
controls) to Firedrake ω_F = [backend.from_ml_backend(ωi, Vi) for ωi, Vi in zip(ω, V)] - # Forward operator: Delegated to pyadjoint.ReducedFunctional which recomputes the blocks on the tape + # Forward operator: delegated to pyadjoint.ReducedFunctional which recomputes the blocks on the tape y_F = F(ω_F) # Stash metadata to the PyTorch context ctx.metadata.update(metadata) @@ -51,10 +51,10 @@ def backward(ctx, grad_output): if isinstance(adj_input, Function): adj_input = adj_input.vector() - # Compute adjoint model of `F`: Delegated to pyadjoint.ReducedFunctional + # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional Δω = F.derivative(adj_input=adj_input) - # Tuplify + # Tuplify adjoint output Δω = (Δω,) if not isinstance(Δω, collections.abc.Sequence) else Δω # None is for metadata arg in `forward` From e8195242d3a8239dfe48865cd6823c1d46e08893 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Tue, 7 Mar 2023 23:48:20 +0000 Subject: [PATCH 16/48] Move pytorch coupling code to preconditioners/ --- firedrake/__init__.py | 1 - firedrake/external_operators/__init__.py | 1 - firedrake/preconditioners/__init__.py | 1 + .../pytorch_coupling}/__init__.py | 0 .../pytorch_coupling}/backends.py | 2 +- .../pytorch_coupling}/pytorch_custom_operator.py | 2 +- 6 files changed, 3 insertions(+), 4 deletions(-) delete mode 100644 firedrake/external_operators/__init__.py rename firedrake/{external_operators/neural_networks => preconditioners/pytorch_coupling}/__init__.py (100%) rename firedrake/{external_operators/neural_networks => preconditioners/pytorch_coupling}/backends.py (97%) rename firedrake/{external_operators/neural_networks => preconditioners/pytorch_coupling}/pytorch_custom_operator.py (97%) diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 8e25e01e1c..3fd3eb910c 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -103,7 +103,6 @@ from firedrake.ensemble import * from firedrake.randomfunctiongen import * from firedrake.progress_bar import ProgressBar # noqa: F401 -from firedrake.external_operators import * from firedrake.logging import * # Set default log level diff --git a/firedrake/external_operators/__init__.py b/firedrake/external_operators/__init__.py deleted file mode 100644 index d37df34173..0000000000 --- a/firedrake/external_operators/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from firedrake.external_operators.neural_networks import * \ No newline at end of file diff --git a/firedrake/preconditioners/__init__.py b/firedrake/preconditioners/__init__.py index e63a3c38f1..a1f6a0a110 100644 --- a/firedrake/preconditioners/__init__.py +++ b/firedrake/preconditioners/__init__.py @@ -11,3 +11,4 @@ from firedrake.preconditioners.hypre_ads import * # noqa: F401 from firedrake.preconditioners.fdm import * # noqa: F401 from firedrake.preconditioners.facet_split import * # noqa: F401 +from firedrake.preconditioners.pytorch_coupling import * # noqa: F401 diff --git a/firedrake/external_operators/neural_networks/__init__.py b/firedrake/preconditioners/pytorch_coupling/__init__.py similarity index 100% rename from firedrake/external_operators/neural_networks/__init__.py rename to firedrake/preconditioners/pytorch_coupling/__init__.py diff --git a/firedrake/external_operators/neural_networks/backends.py b/firedrake/preconditioners/pytorch_coupling/backends.py similarity index 97% rename from firedrake/external_operators/neural_networks/backends.py rename to firedrake/preconditioners/pytorch_coupling/backends.py index 644daba04d..5786f6044a 100644 --- 
a/firedrake/external_operators/neural_networks/backends.py +++ b/firedrake/preconditioners/pytorch_coupling/backends.py @@ -48,7 +48,7 @@ def backend(self): @utils.cached_property def custom_operator(self): - from firedrake.external_operators.neural_networks.pytorch_custom_operator import FiredrakeTorchOperator + from firedrake.preconditioners.pytorch_coupling.pytorch_custom_operator import FiredrakeTorchOperator return FiredrakeTorchOperator().apply def to_ml_backend(self, x, gather=False, batched=True, **kwargs): diff --git a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py b/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py similarity index 97% rename from firedrake/external_operators/neural_networks/pytorch_custom_operator.py rename to firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py index f2df6fe373..7ee98a26ea 100644 --- a/firedrake/external_operators/neural_networks/pytorch_custom_operator.py +++ b/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py @@ -2,7 +2,7 @@ import torch.autograd as torch_ad from functools import partial -from firedrake.external_operators.neural_networks import get_backend +from firedrake.preconditioners.pytorch_coupling import get_backend from firedrake.function import Function from pyadjoint.reduced_functional import ReducedFunctional From 76e7f895de30af84488278fd955db00b3e084511 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 8 Mar 2023 00:01:20 +0000 Subject: [PATCH 17/48] Clean up tests --- .../test_pytorch_coupling.py | 71 +++++++++---------- 1 file changed, 33 insertions(+), 38 deletions(-) rename tests/{external_operators/neural_networks => regression}/test_pytorch_coupling.py (79%) diff --git a/tests/external_operators/neural_networks/test_pytorch_coupling.py b/tests/regression/test_pytorch_coupling.py similarity index 79% rename from tests/external_operators/neural_networks/test_pytorch_coupling.py rename to tests/regression/test_pytorch_coupling.py index 26ca50be8a..36a626f7b7 100644 --- a/tests/external_operators/neural_networks/test_pytorch_coupling.py +++ b/tests/regression/test_pytorch_coupling.py @@ -6,7 +6,6 @@ from firedrake import * from firedrake_adjoint import * -from firedrake.external_operators.neural_networks.backends import get_backend from pyadjoint.tape import get_working_tape, pause_annotation @@ -30,12 +29,12 @@ def handle_annotation(): pause_annotation() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def mesh(): return UnitSquareMesh(10, 10) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def V(mesh): return FunctionSpace(mesh, "CG", 1) @@ -51,27 +50,25 @@ class EncoderDecoder(Module): def __init__(self, n): super(EncoderDecoder, self).__init__() - self.n1 = n - self.n2 = int(n/2) + self.n = n + self.m = int(n/2) self.flatten = Flatten() - self.encoder_1 = Linear(self.n1, self.n2) - self.decoder_1 = Linear(self.n2, self.n1) + self.linear_encoder = Linear(self.n, self.m) + self.linear_decoder = Linear(self.m, self.n) def encode(self, x): - return self.encoder_1(x) + return torch_func.relu(self.linear_encoder(x)) def decode(self, x): - return self.decoder_1(x) + return torch_func.relu(self.linear_decoder(x)) def forward(self, x): # [batch_size, n] x = self.flatten(x) - # [batch_size, n2] - encoded = self.encode(x) - hidden = torch_func.relu(encoded) + # [batch_size, m] + hidden = self.encode(x) # [batch_size, n] - decoded = self.decode(hidden) - return torch_func.relu(decoded) + return self.decode(hidden) # Set 
of Firedrake operations that will be composed with PyTorch operations @@ -84,7 +81,7 @@ def poisson_residual(u, f, V): # Set of Firedrake operations that will be composed with PyTorch operations def solve_poisson(f, V): - """Solve Poisson problem""" + """Solve Poisson problem with homogeneous Dirichlet boundary conditions""" u = Function(V) v = TestFunction(V) F = (inner(grad(u), grad(v)) + inner(u, v) - inner(f, v)) * dx @@ -95,18 +92,18 @@ def solve_poisson(f, V): return assemble(u ** 2 * dx) -@pytest.fixture(params=['poisson_residual', 'solve_poisson']) +@pytest.fixture(params=["poisson_residual", "solve_poisson"]) def firedrake_operator(request, f_exact, V): # Return firedrake operator and the corresponding non-control arguments - if request.param == 'poisson_residual': + if request.param == "poisson_residual": return poisson_residual, (f_exact, V) - elif request.param == 'solve_poisson': + elif request.param == "solve_poisson": return solve_poisson, (V,) @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done def test_pytorch_loss_backward(V, f_exact): - """Add doc """ + """Test backpropagation through a vector-valued Firedrake operator""" # Instantiate model model = EncoderDecoder(V.dim()) @@ -120,27 +117,24 @@ def test_pytorch_loss_backward(V, f_exact): # Get machine learning backend (default: PyTorch) pytorch_backend = get_backend() - # Model input - f = Function(V) - - # Convert f to torch.Tensor - f_P = pytorch_backend.to_ml_backend(f) + # Convert f_exact to torch.Tensor + f_P = pytorch_backend.to_ml_backend(f_exact) # Forward pass - y_P = model(f_P) + u_P = model(f_P) # Set control - u_F = Function(V) - c = Control(u_F) + u = Function(V) + c = Control(u) # Set reduced functional which expresses the Firedrake operations in terms of the control - Jhat = ReducedFunctional(poisson_residual(u_F, f_exact, V), c) + Jhat = ReducedFunctional(poisson_residual(u, f_exact, V), c) # Construct the torch operator that takes a callable representing the Firedrake operations G = torch_operator(Jhat) # Compute Poisson residual in Firedrake using the torch operator: `residual_P` is a torch.Tensor - residual_P = G(y_P) + residual_P = G(u_P) # Compute PyTorch loss loss = (residual_P ** 2).sum() @@ -154,16 +148,16 @@ def test_pytorch_loss_backward(V, f_exact): assert all([θi.grad is not None for θi in model.parameters()]) # -- Check forward operator -- # - y_F = pytorch_backend.from_ml_backend(y_P, V) - residual_F = poisson_residual(y_F, f_exact, V) - residual_P_exact = pytorch_backend.to_ml_backend(residual_F) + u = pytorch_backend.from_ml_backend(u_P, V) + residual = poisson_residual(u, f_exact, V) + residual_P_exact = pytorch_backend.to_ml_backend(residual) assert (residual_P - residual_P_exact).detach().norm() < 1e-10 @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done -def test_firedrake_loss_backward(V, f_exact): - """Add doc """ +def test_firedrake_loss_backward(V): + """Test backpropagation through a scalar-valued Firedrake operator""" # Instantiate model model = EncoderDecoder(V.dim()) @@ -209,15 +203,16 @@ def test_firedrake_loss_backward(V, f_exact): assert all([θi.grad is not None for θi in model.parameters()]) # -- Check forward operator -- # - f_F = pytorch_backend.from_ml_backend(f_P, V) - loss_F = solve_poisson(f_F, V) - loss_P_exact = pytorch_backend.to_ml_backend(loss_F) + f = pytorch_backend.from_ml_backend(f_P, V) + loss = solve_poisson(f, V) + loss_P_exact = pytorch_backend.to_ml_backend(loss) assert (loss_P - loss_P_exact).detach().norm() < 
1e-10 @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done def test_taylor_torch_operator(firedrake_operator, V): + """Taylor test for the torch operator""" # Control value ω = Function(V) # Get Firedrake operator and other operator arguments @@ -229,4 +224,4 @@ def test_taylor_torch_operator(firedrake_operator, V): # `gradcheck` is likey to fail if the inputs are not double precision (cf. https://pytorch.org/docs/stable/generated/torch.autograd.gradcheck.html) x_P = torch.rand(V.dim(), dtype=torch.double, requires_grad=True) # Taylor test (`eps` is the perturbation) - torch.autograd.gradcheck(G, x_P, eps=1e-6) + assert torch.autograd.gradcheck(G, x_P, eps=1e-6) From 2dfb9b81f4edabd271f203cea5362c30355eace2 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 8 Mar 2023 01:24:54 +0000 Subject: [PATCH 18/48] Fix lint and doc --- firedrake/preconditioners/pytorch_coupling/__init__.py | 4 ++-- .../pytorch_coupling/pytorch_custom_operator.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/firedrake/preconditioners/pytorch_coupling/__init__.py b/firedrake/preconditioners/pytorch_coupling/__init__.py index dd31093899..e15b5199f7 100644 --- a/firedrake/preconditioners/pytorch_coupling/__init__.py +++ b/firedrake/preconditioners/pytorch_coupling/__init__.py @@ -1,2 +1,2 @@ -from .backends import get_backend -from .pytorch_custom_operator import torch_operator \ No newline at end of file +from .backends import get_backend # noqa: F401 +from .pytorch_custom_operator import torch_operator # noqa: F401 diff --git a/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py b/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py index 7ee98a26ea..e08190de8d 100644 --- a/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py +++ b/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py @@ -1,5 +1,4 @@ import collections -import torch.autograd as torch_ad from functools import partial from firedrake.preconditioners.pytorch_coupling import get_backend @@ -11,7 +10,7 @@ backend = get_backend('pytorch') -class FiredrakeTorchOperator(torch_ad.Function): +class FiredrakeTorchOperator(backend.backend.autograd.Function): """ PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F. `FiredrakeTorchOperator` is a wrapper around `torch.autograd.Function` that executes forward and backward From 0f511d7752dabe9b49c42c57f06cfb49cc9e4ba2 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 8 Mar 2023 01:25:18 +0000 Subject: [PATCH 19/48] Add torch to test dependencies --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 04f377d84f..a9566464d5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -55,6 +55,7 @@ jobs: python $(which firedrake-clean) python -m pip install pytest-cov pytest-timeout pytest-xdist pytest-timeout python -m pip list + python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu - name: Test Firedrake run: | . 
../build/bin/activate From ad7a008874fa6496a55337f7e329ded03372676b Mon Sep 17 00:00:00 2001 From: Nacime Bouziani <48448063+nbouziani@users.noreply.github.com> Date: Wed, 8 Mar 2023 01:41:59 +0000 Subject: [PATCH 20/48] Remove spurious torch packages --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a9566464d5..57422319a2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -55,7 +55,7 @@ jobs: python $(which firedrake-clean) python -m pip install pytest-cov pytest-timeout pytest-xdist pytest-timeout python -m pip list - python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu - name: Test Firedrake run: | . ../build/bin/activate From 2e6c2301a5b8b721eb12801e7305dfc083318e08 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 8 Mar 2023 13:21:22 +0000 Subject: [PATCH 21/48] Remove spurious torch packages --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a9566464d5..57422319a2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -55,7 +55,7 @@ jobs: python $(which firedrake-clean) python -m pip install pytest-cov pytest-timeout pytest-xdist pytest-timeout python -m pip list - python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu - name: Test Firedrake run: | . ../build/bin/activate From 625ef3c7a901edac3751304d985107035c575579 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 8 Mar 2023 13:42:57 +0000 Subject: [PATCH 22/48] Handle case where PyTorch is not installed using the backend --- .../pytorch_coupling/backends.py | 17 ++++- .../pytorch_custom_operator.py | 16 +++- tests/conftest.py | 11 +++ tests/regression/test_pytorch_coupling.py | 73 ++++++++++--------- 4 files changed, 75 insertions(+), 42 deletions(-) diff --git a/firedrake/preconditioners/pytorch_coupling/backends.py b/firedrake/preconditioners/pytorch_coupling/backends.py index 5786f6044a..cf3f81f6a3 100644 --- a/firedrake/preconditioners/pytorch_coupling/backends.py +++ b/firedrake/preconditioners/pytorch_coupling/backends.py @@ -38,13 +38,22 @@ def get_function_space(self, x): class PytorchBackend(AbstractMLBackend): - @utils.cached_property - def backend(self): + def __init__(self): try: import torch + self._backend = torch except ImportError: - raise ImportError("Error when trying to import PyTorch") - return torch + self._backend = None + + @property + def backend(self): + if self: + return self._backend + else: + raise ImportError("Error when trying to import PyTorch.") + + def __bool__(self): + return self._backend is not None @utils.cached_property def custom_operator(self): diff --git a/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py b/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py index e08190de8d..abe478ae09 100644 --- a/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py +++ b/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py @@ -7,10 +7,19 @@ from pyadjoint.reduced_functional import ReducedFunctional -backend = get_backend('pytorch') +backend = get_backend("pytorch") +if backend: + # PyTorch is 
installed + BackendFunction = backend.backend.autograd.Function +else: + class BackendFunction(object): + """Dummy class that exceptions on instantiation.""" + def __init__(self): + raise ImportError("PyTorch is not installed and is required to use the FiredrakeTorchOperator.") -class FiredrakeTorchOperator(backend.backend.autograd.Function): + +class FiredrakeTorchOperator(BackendFunction): """ PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F. `FiredrakeTorchOperator` is a wrapper around `torch.autograd.Function` that executes forward and backward @@ -24,6 +33,9 @@ class FiredrakeTorchOperator(backend.backend.autograd.Function): y: PyTorch tensor representing the output of the Firedrake operator F """ + def __init__(self): + super(FiredrakeTorchOperator, self).__init__() + # This method is wrapped by something cancelling annotation (probably 'with torch.no_grad()') @staticmethod def forward(ctx, metadata, *ω): diff --git a/tests/conftest.py b/tests/conftest.py index 2d8a620261..afd54c0a18 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,6 +44,9 @@ def pytest_configure(config): config.addinivalue_line( "markers", "skipcomplexnoslate: mark as skipped in complex mode due to lack of Slate") + config.addinivalue_line( + "markers", + "skiptorch: mark as skipped if PyTorch is not installed") @pytest.fixture(autouse=True) @@ -79,6 +82,10 @@ def pytest_runtest_call(item): def pytest_collection_modifyitems(session, config, items): from firedrake.utils import SLATE_SUPPORTS_COMPLEX + from firedrake.preconditioners.pytorch_coupling import get_backend + + backend = get_backend("pytorch") + for item in items: if complex_mode: if item.get_closest_marker("skipcomplex") is not None: @@ -89,6 +96,10 @@ def pytest_collection_modifyitems(session, config, items): if item.get_closest_marker("skipreal") is not None: item.add_marker(pytest.mark.skip(reason="Test makes no sense unless in complex mode")) + if not backend: + if item.get_closest_marker("skiptorch") is not None: + item.add_marker(pytest.mark.skip(reason="Test makes no sense if PyTorch is not installed")) + @pytest.fixture(scope="module", autouse=True) def check_empty_tape(request): diff --git a/tests/regression/test_pytorch_coupling.py b/tests/regression/test_pytorch_coupling.py index 36a626f7b7..83225c3c41 100644 --- a/tests/regression/test_pytorch_coupling.py +++ b/tests/regression/test_pytorch_coupling.py @@ -1,14 +1,44 @@ import pytest -import torch -import torch.nn.functional as torch_func -from torch.nn import Module, Flatten, Linear - from firedrake import * from firedrake_adjoint import * from pyadjoint.tape import get_working_tape, pause_annotation +pytorch_backend = get_backend("pytorch") + +if pytorch_backend: + # PyTorch is installed + import torch + import torch.nn.functional as torch_func + from torch.nn import Module, Flatten, Linear + + class EncoderDecoder(Module): + """Build a simple toy model""" + + def __init__(self, n): + super(EncoderDecoder, self).__init__() + self.n = n + self.m = int(n/2) + self.flatten = Flatten() + self.linear_encoder = Linear(self.n, self.m) + self.linear_decoder = Linear(self.m, self.n) + + def encode(self, x): + return torch_func.relu(self.linear_encoder(x)) + + def decode(self, x): + return torch_func.relu(self.linear_decoder(x)) + + def forward(self, x): + # [batch_size, n] + x = self.flatten(x) + # [batch_size, m] + hidden = self.encode(x) + # [batch_size, n] + return self.decode(hidden) + + @pytest.fixture(autouse=True) def 
handle_taping(): yield @@ -45,32 +75,6 @@ def f_exact(V, mesh): return Function(V).interpolate(sin(pi * x) * sin(pi * y)) -class EncoderDecoder(Module): - """Build a simple toy model""" - - def __init__(self, n): - super(EncoderDecoder, self).__init__() - self.n = n - self.m = int(n/2) - self.flatten = Flatten() - self.linear_encoder = Linear(self.n, self.m) - self.linear_decoder = Linear(self.m, self.n) - - def encode(self, x): - return torch_func.relu(self.linear_encoder(x)) - - def decode(self, x): - return torch_func.relu(self.linear_decoder(x)) - - def forward(self, x): - # [batch_size, n] - x = self.flatten(x) - # [batch_size, m] - hidden = self.encode(x) - # [batch_size, n] - return self.decode(hidden) - - # Set of Firedrake operations that will be composed with PyTorch operations def poisson_residual(u, f, V): """Assemble the residual of a Poisson problem""" @@ -102,6 +106,7 @@ def firedrake_operator(request, f_exact, V): @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done +@pytest.mark.skiptorch # Skip if PyTorch is not installed def test_pytorch_loss_backward(V, f_exact): """Test backpropagation through a vector-valued Firedrake operator""" @@ -114,9 +119,6 @@ def test_pytorch_loss_backward(V, f_exact): # Check that gradients are initially set to None assert all([θi.grad is None for θi in model.parameters()]) - # Get machine learning backend (default: PyTorch) - pytorch_backend = get_backend() - # Convert f_exact to torch.Tensor f_P = pytorch_backend.to_ml_backend(f_exact) @@ -156,6 +158,7 @@ def test_pytorch_loss_backward(V, f_exact): @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done +@pytest.mark.skiptorch # Skip if PyTorch is not installed def test_firedrake_loss_backward(V): """Test backpropagation through a scalar-valued Firedrake operator""" @@ -168,9 +171,6 @@ def test_firedrake_loss_backward(V): # Check that gradients are initially set to None assert all([θi.grad is None for θi in model.parameters()]) - # Get machine learning backend (default: PyTorch) - pytorch_backend = get_backend() - # Model input λ = Function(V) @@ -211,6 +211,7 @@ def test_firedrake_loss_backward(V): @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done +@pytest.mark.skiptorch # Skip if PyTorch is not installed def test_taylor_torch_operator(firedrake_operator, V): """Taylor test for the torch operator""" # Control value From 369c070fc5c4f9d98f5b32fd7bffc66d3652e80a Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 8 Mar 2023 14:11:51 +0000 Subject: [PATCH 23/48] Clean up --- firedrake/__init__.py | 1 + firedrake/preconditioners/__init__.py | 1 - .../pytorch_coupling/__init__.py | 2 - firedrake/pytorch_coupling/__init__.py | 2 + .../pytorch_coupling/backends.py | 38 +++++++++++-------- .../pytorch_custom_operator.py | 2 +- tests/conftest.py | 2 +- 7 files changed, 27 insertions(+), 21 deletions(-) delete mode 100644 firedrake/preconditioners/pytorch_coupling/__init__.py create mode 100644 firedrake/pytorch_coupling/__init__.py rename firedrake/{preconditioners => }/pytorch_coupling/backends.py (73%) rename firedrake/{preconditioners => }/pytorch_coupling/pytorch_custom_operator.py (97%) diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 3fd3eb910c..69850dbe69 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -103,6 +103,7 @@ from firedrake.ensemble import * from firedrake.randomfunctiongen import * from firedrake.progress_bar import ProgressBar # noqa: F401 +from firedrake.pytorch_coupling import * from 
firedrake.logging import * # Set default log level diff --git a/firedrake/preconditioners/__init__.py b/firedrake/preconditioners/__init__.py index a1f6a0a110..e63a3c38f1 100644 --- a/firedrake/preconditioners/__init__.py +++ b/firedrake/preconditioners/__init__.py @@ -11,4 +11,3 @@ from firedrake.preconditioners.hypre_ads import * # noqa: F401 from firedrake.preconditioners.fdm import * # noqa: F401 from firedrake.preconditioners.facet_split import * # noqa: F401 -from firedrake.preconditioners.pytorch_coupling import * # noqa: F401 diff --git a/firedrake/preconditioners/pytorch_coupling/__init__.py b/firedrake/preconditioners/pytorch_coupling/__init__.py deleted file mode 100644 index e15b5199f7..0000000000 --- a/firedrake/preconditioners/pytorch_coupling/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .backends import get_backend # noqa: F401 -from .pytorch_custom_operator import torch_operator # noqa: F401 diff --git a/firedrake/pytorch_coupling/__init__.py b/firedrake/pytorch_coupling/__init__.py new file mode 100644 index 0000000000..9db6ef30c8 --- /dev/null +++ b/firedrake/pytorch_coupling/__init__.py @@ -0,0 +1,2 @@ +from firedrake.pytorch_coupling.backends import get_backend # noqa: F401 +from firedrake.pytorch_coupling.pytorch_custom_operator import torch_operator # noqa: F401 diff --git a/firedrake/preconditioners/pytorch_coupling/backends.py b/firedrake/pytorch_coupling/backends.py similarity index 73% rename from firedrake/preconditioners/pytorch_coupling/backends.py rename to firedrake/pytorch_coupling/backends.py index cf3f81f6a3..5278ab61fe 100644 --- a/firedrake/preconditioners/pytorch_coupling/backends.py +++ b/firedrake/pytorch_coupling/backends.py @@ -1,5 +1,3 @@ -import numpy as np - from firedrake.function import Function from firedrake.vector import Vector from firedrake.constant import Constant @@ -13,14 +11,16 @@ def backend(self): raise NotImplementedError def to_ml_backend(self, x): - """Convert from Firedrake to ML backend - x: Firedrake object + r"""Convert from Firedrake to ML backend. + + :arg x: Firedrake object """ raise NotImplementedError def from_ml_backend(self, x, V): - """Convert from ML backend to Firedrake - x: ML backend object + r"""Convert from ML backend to Firedrake. + + :arg x: ML backend object """ raise NotImplementedError @@ -57,23 +57,21 @@ def __bool__(self): @utils.cached_property def custom_operator(self): - from firedrake.preconditioners.pytorch_coupling.pytorch_custom_operator import FiredrakeTorchOperator + from firedrake.pytorch_coupling.pytorch_custom_operator import FiredrakeTorchOperator return FiredrakeTorchOperator().apply def to_ml_backend(self, x, gather=False, batched=True, **kwargs): - """ Convert a Firedrake object `x` into a PyTorch tensor + r"""Convert a Firedrake object `x` into a PyTorch tensor. 
- x: Firedrake object (Function, Vector, Constant) - gather: if True, gather data from all processes - batched: if True, add a batch dimension to the tensor - kwargs: additional arguments to be passed to torch.Tensor constructor + :arg x: Firedrake object (Function, Vector, Constant) + :kwarg gather: if True, gather data from all processes + :kwarg batched: if True, add a batch dimension to the tensor + :kwarg kwargs: additional arguments to be passed to torch.Tensor constructor - device: device on which the tensor is allocated (default: "cpu") - dtype: the desired data type of returned tensor (default: type of x.dat.data) - requires_grad: if the tensor should be annotated (default: False) """ if isinstance(x, (Function, Vector)): - # State counter: get_local does a copy and increase the state counter while gather does not. - # We probably always want to increase the state counter and therefore should do something for the gather case if gather: # Gather data from all processes x_P = self.backend.tensor(x.vector().gather(), **kwargs) @@ -87,11 +85,19 @@ def to_ml_backend(self, x, gather=False, batched=True, **kwargs): elif isinstance(x, Constant): return self.backend.tensor(x.values(), **kwargs) elif isinstance(x, (float, int)): - return self.backend.tensor(np.array(x), **kwargs) + if isinstance(x, float): + # Set double-precision + kwargs['dtype'] = self.backend.double + return self.backend.tensor(x, **kwargs) else: raise ValueError("Cannot convert %s to a torch tensor" % str(type(x))) - def from_ml_backend(self, x, V=None, gather=False): + def from_ml_backend(self, x, V=None): + r"""Convert a PyTorch tensor `x` into a Firedrake object. + + :arg x: PyTorch tensor (torch.Tensor) + :kwarg V: function space of the corresponding Function or None when `x` is to be mapped to a Constant + """ if x.device.type != "cpu": raise NotImplementedError("Firedrake does not support GPU tensors") diff --git a/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py b/firedrake/pytorch_coupling/pytorch_custom_operator.py similarity index 97% rename from firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py rename to firedrake/pytorch_coupling/pytorch_custom_operator.py index abe478ae09..b6dd902db9 100644 --- a/firedrake/preconditioners/pytorch_coupling/pytorch_custom_operator.py +++ b/firedrake/pytorch_coupling/pytorch_custom_operator.py @@ -1,7 +1,7 @@ import collections from functools import partial -from firedrake.preconditioners.pytorch_coupling import get_backend +from firedrake.pytorch_coupling import get_backend from firedrake.function import Function from pyadjoint.reduced_functional import ReducedFunctional diff --git a/tests/conftest.py b/tests/conftest.py index afd54c0a18..52943fbb55 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -82,7 +82,7 @@ def pytest_runtest_call(item): def pytest_collection_modifyitems(session, config, items): from firedrake.utils import SLATE_SUPPORTS_COMPLEX - from firedrake.preconditioners.pytorch_coupling import get_backend + from firedrake.pytorch_coupling import get_backend backend = get_backend("pytorch") From 8764522011a064b9d799fa53e452828d6ab6e801 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 8 Mar 2023 15:56:37 +0000 Subject: [PATCH 24/48] Add pyadjoint branch --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 57422319a2..6655b567c5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ 
-48,7 +48,7 @@ jobs: - name: Build Firedrake run: | cd .. - ./firedrake/scripts/firedrake-install $COMPLEX --venv-name build --tinyasm --disable-ssh --minimal-petsc --slepc --documentation-dependencies --install thetis --install gusto --install icepack --install irksome --install femlium --no-package-manager || (cat firedrake-install.log && /bin/false) + ./firedrake/scripts/firedrake-install $COMPLEX --venv-name build --tinyasm --disable-ssh --minimal-petsc --slepc --documentation-dependencies --install thetis --install gusto --install icepack --install irksome --install femlium --no-package-manager --package-branch pyadjoint adjoint-1-forms || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | . ../build/bin/activate From 8bde10ba04bba19a852ce05792683989c7428930 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Thu, 9 Mar 2023 09:21:11 +0000 Subject: [PATCH 25/48] Address PR's comments --- firedrake/__init__.py | 2 +- firedrake/ml_coupling/__init__.py | 2 + firedrake/ml_coupling/backend_base.py | 41 +++++++++++++++++++ .../pytorch/backend.py} | 41 +------------------ .../pytorch}/pytorch_custom_operator.py | 32 +++++++-------- firedrake/pytorch_coupling/__init__.py | 2 - tests/conftest.py | 6 +-- tests/regression/test_pytorch_coupling.py | 12 ++++-- 8 files changed, 74 insertions(+), 64 deletions(-) create mode 100644 firedrake/ml_coupling/__init__.py create mode 100644 firedrake/ml_coupling/backend_base.py rename firedrake/{pytorch_coupling/backends.py => ml_coupling/pytorch/backend.py} (72%) rename firedrake/{pytorch_coupling => ml_coupling/pytorch}/pytorch_custom_operator.py (74%) delete mode 100644 firedrake/pytorch_coupling/__init__.py diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 69850dbe69..ddc318c04b 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -103,7 +103,7 @@ from firedrake.ensemble import * from firedrake.randomfunctiongen import * from firedrake.progress_bar import ProgressBar # noqa: F401 -from firedrake.pytorch_coupling import * +from firedrake.ml_coupling import * from firedrake.logging import * # Set default log level diff --git a/firedrake/ml_coupling/__init__.py b/firedrake/ml_coupling/__init__.py new file mode 100644 index 0000000000..b6e2cd7195 --- /dev/null +++ b/firedrake/ml_coupling/__init__.py @@ -0,0 +1,2 @@ +from firedrake.ml_coupling.backend_base import load_backend # noqa: F401 +from firedrake.ml_coupling.pytorch.pytorch_custom_operator import torch_operator # noqa: F401 diff --git a/firedrake/ml_coupling/backend_base.py b/firedrake/ml_coupling/backend_base.py new file mode 100644 index 0000000000..d6acec3a72 --- /dev/null +++ b/firedrake/ml_coupling/backend_base.py @@ -0,0 +1,41 @@ +from firedrake.function import Function +from firedrake.vector import Vector + + +class AbstractMLBackend(object): + + def backend(self): + raise NotImplementedError + + def to_ml_backend(self, x): + r"""Convert from Firedrake to ML backend. + + :arg x: Firedrake object + """ + raise NotImplementedError + + def from_ml_backend(self, x, V): + r"""Convert from ML backend to Firedrake. 
+ + :arg x: ML backend object + """ + raise NotImplementedError + + def function_space(self, x): + """Get function space out of x""" + if isinstance(x, Function): + return x.function_space() + elif isinstance(x, Vector): + return self.function_space(x.function) + elif isinstance(x, float): + return None + else: + raise ValueError("Cannot infer the function space of %s" % x) + + +def load_backend(backend_name='pytorch'): + if backend_name == 'pytorch': + from firedrake.ml_coupling.pytorch.backend import PytorchBackend + return PytorchBackend() + else: + raise NotImplementedError("The backend: %s is not supported." % backend_name) diff --git a/firedrake/pytorch_coupling/backends.py b/firedrake/ml_coupling/pytorch/backend.py similarity index 72% rename from firedrake/pytorch_coupling/backends.py rename to firedrake/ml_coupling/pytorch/backend.py index 5278ab61fe..0da4d557f8 100644 --- a/firedrake/pytorch_coupling/backends.py +++ b/firedrake/ml_coupling/pytorch/backend.py @@ -1,41 +1,11 @@ from firedrake.function import Function from firedrake.vector import Vector from firedrake.constant import Constant +from firedrake.ml_coupling.backend_base import AbstractMLBackend import firedrake.utils as utils -class AbstractMLBackend(object): - - def backend(self): - raise NotImplementedError - - def to_ml_backend(self, x): - r"""Convert from Firedrake to ML backend. - - :arg x: Firedrake object - """ - raise NotImplementedError - - def from_ml_backend(self, x, V): - r"""Convert from ML backend to Firedrake. - - :arg x: ML backend object - """ - raise NotImplementedError - - def get_function_space(self, x): - """Get function space out of x""" - if isinstance(x, Function): - return x.function_space() - elif isinstance(x, Vector): - return self.get_function_space(x.function) - elif isinstance(x, float): - return None - else: - raise ValueError("Cannot infer the function space of %s" % x) - - class PytorchBackend(AbstractMLBackend): def __init__(self): @@ -57,7 +27,7 @@ def __bool__(self): @utils.cached_property def custom_operator(self): - from firedrake.pytorch_coupling.pytorch_custom_operator import FiredrakeTorchOperator + from firedrake.ml_coupling.pytorch.pytorch_custom_operator import FiredrakeTorchOperator return FiredrakeTorchOperator().apply def to_ml_backend(self, x, gather=False, batched=True, **kwargs): @@ -111,10 +81,3 @@ def from_ml_backend(self, x, V=None): x_F = Function(V, dtype=x.dtype) x_F.vector().set_local(x) return x_F - - -def get_backend(backend_name='pytorch'): - if backend_name == 'pytorch': - return PytorchBackend() - else: - raise NotImplementedError("The backend: %s is not supported." 
% backend_name) diff --git a/firedrake/pytorch_coupling/pytorch_custom_operator.py b/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py similarity index 74% rename from firedrake/pytorch_coupling/pytorch_custom_operator.py rename to firedrake/ml_coupling/pytorch/pytorch_custom_operator.py index b6dd902db9..28eb32dc85 100644 --- a/firedrake/pytorch_coupling/pytorch_custom_operator.py +++ b/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py @@ -1,19 +1,19 @@ import collections from functools import partial -from firedrake.pytorch_coupling import get_backend +from firedrake.ml_coupling import load_backend from firedrake.function import Function from pyadjoint.reduced_functional import ReducedFunctional -backend = get_backend("pytorch") +backend = load_backend("pytorch") if backend: # PyTorch is installed BackendFunction = backend.backend.autograd.Function else: - class BackendFunction(object): + class BackendFunction(): """Dummy class that exceptions on instantiation.""" def __init__(self): raise ImportError("PyTorch is not installed and is required to use the FiredrakeTorchOperator.") @@ -27,10 +27,10 @@ class FiredrakeTorchOperator(BackendFunction): Inputs: metadata: dictionary used to stash Firedrake objects. - *ω: PyTorch tensors representing the inputs to the Firedrake operator F + *x_P: PyTorch tensors representing the inputs to the Firedrake operator F Outputs: - y: PyTorch tensor representing the output of the Firedrake operator F + y_P: PyTorch tensor representing the output of the Firedrake operator F """ def __init__(self): @@ -38,19 +38,19 @@ def __init__(self): # This method is wrapped by something cancelling annotation (probably 'with torch.no_grad()') @staticmethod - def forward(ctx, metadata, *ω): + def forward(ctx, metadata, *x_P): """Forward pass of the PyTorch custom operator.""" F = metadata['F'] V = metadata['V_controls'] # Convert PyTorch input (i.e. 
controls) to Firedrake - ω_F = [backend.from_ml_backend(ωi, Vi) for ωi, Vi in zip(ω, V)] + x_F = [backend.from_ml_backend(xi, Vi) for xi, Vi in zip(x_P, V)] # Forward operator: delegated to pyadjoint.ReducedFunctional which recomputes the blocks on the tape - y_F = F(ω_F) + y_F = F(x_F) # Stash metadata to the PyTorch context ctx.metadata.update(metadata) # Convert Firedrake output to PyTorch - y = backend.to_ml_backend(y_F) - return y.detach() + y_P = backend.to_ml_backend(y_F) + return y_P.detach() @staticmethod def backward(ctx, grad_output): @@ -63,13 +63,13 @@ def backward(ctx, grad_output): adj_input = adj_input.vector() # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional - Δω = F.derivative(adj_input=adj_input) + adj_output = F.derivative(adj_input=adj_input) # Tuplify adjoint output - Δω = (Δω,) if not isinstance(Δω, collections.abc.Sequence) else Δω + adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output # None is for metadata arg in `forward` - return None, *[backend.to_ml_backend(Δωi) for Δωi in Δω] + return None, *[backend.to_ml_backend(di) for di in adj_output] def torch_operator(F): @@ -79,8 +79,8 @@ def torch_operator(F): if not isinstance(F, ReducedFunctional): raise ValueError("F must be a ReducedFunctional") - V_output = backend.get_function_space(F.functional) + V_output = backend.function_space(F.functional) V_controls = [c.control.function_space() for c in F.controls] metadata = {'F': F, 'V_controls': V_controls, 'V_output': V_output} - φ = partial(backend.custom_operator, metadata) - return φ + F_P = partial(backend.custom_operator, metadata) + return F_P diff --git a/firedrake/pytorch_coupling/__init__.py b/firedrake/pytorch_coupling/__init__.py deleted file mode 100644 index 9db6ef30c8..0000000000 --- a/firedrake/pytorch_coupling/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from firedrake.pytorch_coupling.backends import get_backend # noqa: F401 -from firedrake.pytorch_coupling.pytorch_custom_operator import torch_operator # noqa: F401 diff --git a/tests/conftest.py b/tests/conftest.py index 52943fbb55..1e4b0867f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -82,9 +82,9 @@ def pytest_runtest_call(item): def pytest_collection_modifyitems(session, config, items): from firedrake.utils import SLATE_SUPPORTS_COMPLEX - from firedrake.pytorch_coupling import get_backend + from firedrake.ml_coupling import load_backend - backend = get_backend("pytorch") + ml_backend = load_backend("pytorch") for item in items: if complex_mode: @@ -96,7 +96,7 @@ def pytest_collection_modifyitems(session, config, items): if item.get_closest_marker("skipreal") is not None: item.add_marker(pytest.mark.skip(reason="Test makes no sense unless in complex mode")) - if not backend: + if not ml_backend: if item.get_closest_marker("skiptorch") is not None: item.add_marker(pytest.mark.skip(reason="Test makes no sense if PyTorch is not installed")) diff --git a/tests/regression/test_pytorch_coupling.py b/tests/regression/test_pytorch_coupling.py index 83225c3c41..bf82ed3d31 100644 --- a/tests/regression/test_pytorch_coupling.py +++ b/tests/regression/test_pytorch_coupling.py @@ -1,11 +1,10 @@ import pytest from firedrake import * -from firedrake_adjoint import * from pyadjoint.tape import get_working_tape, pause_annotation -pytorch_backend = get_backend("pytorch") +pytorch_backend = load_backend("pytorch") if pytorch_backend: # PyTorch is installed @@ -110,6 +109,8 @@ def firedrake_operator(request, f_exact, V): def 
test_pytorch_loss_backward(V, f_exact): """Test backpropagation through a vector-valued Firedrake operator""" + from firedrake_adjoint import ReducedFunctional, Control + # Instantiate model model = EncoderDecoder(V.dim()) @@ -162,6 +163,8 @@ def test_pytorch_loss_backward(V, f_exact): def test_firedrake_loss_backward(V): """Test backpropagation through a scalar-valued Firedrake operator""" + from firedrake_adjoint import ReducedFunctional, Control + # Instantiate model model = EncoderDecoder(V.dim()) @@ -214,6 +217,9 @@ def test_firedrake_loss_backward(V): @pytest.mark.skiptorch # Skip if PyTorch is not installed def test_taylor_torch_operator(firedrake_operator, V): """Taylor test for the torch operator""" + + from firedrake_adjoint import ReducedFunctional, Control + # Control value ω = Function(V) # Get Firedrake operator and other operator arguments @@ -222,7 +228,7 @@ def test_taylor_torch_operator(firedrake_operator, V): Jhat = ReducedFunctional(fd_op(ω, *args), Control(ω)) # Define the torch operator G = torch_operator(Jhat) - # `gradcheck` is likey to fail if the inputs are not double precision (cf. https://pytorch.org/docs/stable/generated/torch.autograd.gradcheck.html) + # `gradcheck` is likely to fail if the inputs are not double precision (cf. https://pytorch.org/docs/stable/generated/torch.autograd.gradcheck.html) x_P = torch.rand(V.dim(), dtype=torch.double, requires_grad=True) # Taylor test (`eps` is the perturbation) assert torch.autograd.gradcheck(G, x_P, eps=1e-6) From d520c28f3fb66473e2ab86e26f8e791c0877e0fb Mon Sep 17 00:00:00 2001 From: Nacime Bouziani <48448063+nbouziani@users.noreply.github.com> Date: Thu, 9 Mar 2023 12:47:09 +0000 Subject: [PATCH 26/48] Update build.yml --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6655b567c5..57422319a2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -48,7 +48,7 @@ jobs: - name: Build Firedrake run: | cd .. - ./firedrake/scripts/firedrake-install $COMPLEX --venv-name build --tinyasm --disable-ssh --minimal-petsc --slepc --documentation-dependencies --install thetis --install gusto --install icepack --install irksome --install femlium --no-package-manager --package-branch pyadjoint adjoint-1-forms || (cat firedrake-install.log && /bin/false) + ./firedrake/scripts/firedrake-install $COMPLEX --venv-name build --tinyasm --disable-ssh --minimal-petsc --slepc --documentation-dependencies --install thetis --install gusto --install icepack --install irksome --install femlium --no-package-manager || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | . ../build/bin/activate From 97d1001f3a9c1908f5b177bf0eef6ae78f8137fc Mon Sep 17 00:00:00 2001 From: nbouziani Date: Fri, 10 Mar 2023 14:03:15 +0000 Subject: [PATCH 27/48] Add option to install torch --- .github/workflows/build.yml | 3 +-- scripts/firedrake-install | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 57422319a2..63e46541ca 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -48,14 +48,13 @@ jobs: - name: Build Firedrake run: | cd .. 
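# Not part of the patch: a sketch of how the install option introduced below is expected to be
# invoked once `scripts/firedrake-install` understands it (flag names as added in this patch):
#
#   python3 firedrake-install --torch          # PyTorch CPU wheels (the default choice)
#   python3 firedrake-install --torch cuda     # CUDA wheels instead
#
# For the CPU case this boils down to roughly:
#
#   python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu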
- ./firedrake/scripts/firedrake-install $COMPLEX --venv-name build --tinyasm --disable-ssh --minimal-petsc --slepc --documentation-dependencies --install thetis --install gusto --install icepack --install irksome --install femlium --no-package-manager || (cat firedrake-install.log && /bin/false) + ./firedrake/scripts/firedrake-install $COMPLEX --venv-name build --tinyasm --torch --disable-ssh --minimal-petsc --slepc --documentation-dependencies --install thetis --install gusto --install icepack --install irksome --install femlium --no-package-manager || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | . ../build/bin/activate python $(which firedrake-clean) python -m pip install pytest-cov pytest-timeout pytest-xdist pytest-timeout python -m pip list - python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu - name: Test Firedrake run: | . ../build/bin/activate diff --git a/scripts/firedrake-install b/scripts/firedrake-install index 2f70bf9762..070c6eecea 100755 --- a/scripts/firedrake-install +++ b/scripts/firedrake-install @@ -68,7 +68,7 @@ class FiredrakeConfiguration(dict): "minimal_petsc", "mpicc", "mpicxx", "mpif90", "mpiexec", "disable_ssh", "honour_petsc_dir", "with_parmetis", "slepc", "packages", "honour_pythonpath", - "opencascade", "tinyasm", + "opencascade", "tinyasm", "torch", "petsc_int_type", "cache_dir", "complex", "remove_build_files", "with_blas"] @@ -259,6 +259,8 @@ honoured.""", help="Install OpenCASCADE for CAD integration.") parser.add_argument("--tinyasm", action="store_true", help="Install TinyASM as backend for ASMPatchPC.") + parser.add_argument("--torch", const="cpu", default=False, nargs='?', choices=["cpu", "cuda"], + help="Install PyTorch") parser.add_argument("--disable-ssh", action="store_true", help="Do not attempt to use ssh to clone git repositories: fall immediately back to https.") parser.add_argument("--no-package-manager", action='store_false', dest="package_manager", @@ -393,6 +395,8 @@ else: help="Install OpenCASCADE for CAD integration.") parser.add_argument("--tinyasm", action="store_true", dest="tinyasm", default=config["options"].get("tinyasm", False), help="Install TinyASM as backend for ASMPatchPC.") + parser.add_argument("--torch", const="cpu", nargs='?', choices=["cpu", "cuda"], default=config["options"].get("torch", False), + help="Install PyTorch") parser.add_argument("--honour-petsc-dir", action="store_true", default=config["options"]["honour_petsc_dir"], help="Usually it is best to let Firedrake build its own PETSc. 
If you wish to use another PETSc, set PETSC_DIR and pass this option.") @@ -1334,6 +1338,12 @@ def build_and_install_pythonocc(): log.info("No need to rebuild pythonocc-core") +def build_and_install_torch(): + log.info("Installing PyTorch") + extra_index_url = ["--extra-index-url", "https://download.pytorch.org/whl/cpu"] if args.torch == "cpu" else [] + run_pip_install(["torch"] + extra_index_url) + + def build_and_install_libspatialindex(): log.info("Installing libspatialindex") if os.path.exists("libspatialindex"): @@ -2058,6 +2068,10 @@ with pipargs("--no-deps"): if options["opencascade"]: build_and_install_pythonocc() +with pipargs("--no-deps"): + if options["torch"]: + build_and_install_torch() + if args.documentation_dependencies: install_documentation_dependencies() From 125cff961b8a68489d30fb81b8fa550bd160e1e2 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Fri, 10 Mar 2023 14:22:32 +0000 Subject: [PATCH 28/48] Add doc for torch installation --- scripts/firedrake-install | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/firedrake-install b/scripts/firedrake-install index 070c6eecea..097965d03f 100755 --- a/scripts/firedrake-install +++ b/scripts/firedrake-install @@ -260,7 +260,7 @@ honoured.""", parser.add_argument("--tinyasm", action="store_true", help="Install TinyASM as backend for ASMPatchPC.") parser.add_argument("--torch", const="cpu", default=False, nargs='?', choices=["cpu", "cuda"], - help="Install PyTorch") + help="Install PyTorch for a CPU or CUDA backend (default: CPU).") parser.add_argument("--disable-ssh", action="store_true", help="Do not attempt to use ssh to clone git repositories: fall immediately back to https.") parser.add_argument("--no-package-manager", action='store_false', dest="package_manager", @@ -396,7 +396,7 @@ else: parser.add_argument("--tinyasm", action="store_true", dest="tinyasm", default=config["options"].get("tinyasm", False), help="Install TinyASM as backend for ASMPatchPC.") parser.add_argument("--torch", const="cpu", nargs='?', choices=["cpu", "cuda"], default=config["options"].get("torch", False), - help="Install PyTorch") + help="Install PyTorch for a CPU or CUDA backend (default: CPU).") parser.add_argument("--honour-petsc-dir", action="store_true", default=config["options"]["honour_petsc_dir"], help="Usually it is best to let Firedrake build its own PETSc. 
If you wish to use another PETSc, set PETSC_DIR and pass this option.") @@ -1339,7 +1339,10 @@ def build_and_install_pythonocc(): def build_and_install_torch(): - log.info("Installing PyTorch") + """Install PyTorch for a CPU or CUDA backend.""" + log.info("Installing PyTorch (backend: %s)" % args.torch) + if osname == "Darwin" and args.torch == "cuda": + raise InstallError("CUDA installation is not available on MacOS.") extra_index_url = ["--extra-index-url", "https://download.pytorch.org/whl/cpu"] if args.torch == "cpu" else [] run_pip_install(["torch"] + extra_index_url) From 8d75d6dfbe280ef51d1a5ec32afd3ecafb58fe30 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Tue, 14 Mar 2023 13:49:25 +0000 Subject: [PATCH 29/48] Add citation --- .../ml_coupling/pytorch/pytorch_custom_operator.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py b/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py index 28eb32dc85..b0e1d79aef 100644 --- a/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py +++ b/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py @@ -3,10 +3,22 @@ from firedrake.ml_coupling import load_backend from firedrake.function import Function +from firedrake_citations import Citations from pyadjoint.reduced_functional import ReducedFunctional +Citations().add("Bouziani2023", """ +@inproceedings{Bouziani2023, + title = {Physics-driven machine learning models coupling {PyTorch} and {Firedrake}}, + author = {Bouziani, Nacime and Ham, David A.}, + booktitle = {{ICLR} 2023 {Workshop} on {Physics} for {Machine} {Learning}}, + year = {2023}, + doi = {10.48550/arXiv.2303.06871} +} +""") + + backend = load_backend("pytorch") if backend: @@ -76,6 +88,8 @@ def torch_operator(F): """Operator that converts a pyadjoint.ReducedFunctional into a firedrake.FiredrakeTorchOperator whose inputs and outputs are PyTorch tensors. """ + Citations().register("Bouziani2023") + if not isinstance(F, ReducedFunctional): raise ValueError("F must be a ReducedFunctional") From 01627ecbf5f52092a12268a4d864fe7e4dfae19f Mon Sep 17 00:00:00 2001 From: nbouziani Date: Fri, 17 Mar 2023 15:34:54 +0000 Subject: [PATCH 30/48] Remove non ascii characters --- firedrake/ml_coupling/pytorch/backend.py | 7 ++----- tests/regression/test_pytorch_coupling.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/firedrake/ml_coupling/pytorch/backend.py b/firedrake/ml_coupling/pytorch/backend.py index 0da4d557f8..6754c39ad4 100644 --- a/firedrake/ml_coupling/pytorch/backend.py +++ b/firedrake/ml_coupling/pytorch/backend.py @@ -3,8 +3,6 @@ from firedrake.constant import Constant from firedrake.ml_coupling.backend_base import AbstractMLBackend -import firedrake.utils as utils - class PytorchBackend(AbstractMLBackend): @@ -25,10 +23,9 @@ def backend(self): def __bool__(self): return self._backend is not None - @utils.cached_property - def custom_operator(self): + def custom_operator(self, *args, **kwargs): from firedrake.ml_coupling.pytorch.pytorch_custom_operator import FiredrakeTorchOperator - return FiredrakeTorchOperator().apply + return FiredrakeTorchOperator.apply(*args, **kwargs) def to_ml_backend(self, x, gather=False, batched=True, **kwargs): r"""Convert a Firedrake object `x` into a PyTorch tensor. 
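For orientation, the following sketch (not part of any patch in the series) shows how the pieces introduced so far fit together: a pyadjoint ReducedFunctional is wrapped as a PyTorch-differentiable callable and exercised from the PyTorch side. The module path `firedrake.ml_coupling` reflects the layout at this point in the branch (a later patch renames it to `firedrake.ml`); the mesh, function space and functional are invented purely for illustration.

    import torch
    from firedrake import *
    from firedrake_adjoint import ReducedFunctional, Control
    from firedrake.ml_coupling import load_backend, torch_operator

    backend = load_backend("pytorch")

    mesh = UnitSquareMesh(8, 8)
    V = FunctionSpace(mesh, "CG", 1)

    # Record a simple scalar-valued Firedrake computation on the pyadjoint tape
    u = Function(V)
    J = assemble(u ** 2 * dx)
    Jhat = ReducedFunctional(J, Control(u))

    # Wrap the reduced functional as a PyTorch-differentiable callable
    G = torch_operator(Jhat)

    # Forward and adjoint passes now run through Firedrake/pyadjoint;
    # gradcheck Taylor-tests the coupling (double-precision inputs required)
    x_P = torch.rand(V.dim(), dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(G, x_P, eps=1e-6)

    # The conversion helpers can also be used on their own
    u_P = backend.to_ml_backend(u)          # firedrake.Function -> torch.Tensor (batched)
    u_F = backend.from_ml_backend(u_P, V)   # torch.Tensor -> firedrake.Function on V

This is essentially the pattern exercised by `test_taylor_torch_operator` in the test file below, and the two conversion helpers at the end are what the regression tests use to move data between `Function` objects and tensors.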
diff --git a/tests/regression/test_pytorch_coupling.py b/tests/regression/test_pytorch_coupling.py index bf82ed3d31..1e1ed9f112 100644 --- a/tests/regression/test_pytorch_coupling.py +++ b/tests/regression/test_pytorch_coupling.py @@ -118,7 +118,7 @@ def test_pytorch_loss_backward(V, f_exact): model.double() # Check that gradients are initially set to None - assert all([θi.grad is None for θi in model.parameters()]) + assert all([pi.grad is None for pi in model.parameters()]) # Convert f_exact to torch.Tensor f_P = pytorch_backend.to_ml_backend(f_exact) @@ -148,7 +148,7 @@ def test_pytorch_loss_backward(V, f_exact): # Check that gradients were propagated to model parameters # This test doesn't check the correctness of these gradients # -> This is checked in `test_taylor_torch_operator` - assert all([θi.grad is not None for θi in model.parameters()]) + assert all([pi.grad is not None for pi in model.parameters()]) # -- Check forward operator -- # u = pytorch_backend.from_ml_backend(u_P, V) @@ -172,16 +172,16 @@ def test_firedrake_loss_backward(V): model.double() # Check that gradients are initially set to None - assert all([θi.grad is None for θi in model.parameters()]) + assert all([pi.grad is None for pi in model.parameters()]) # Model input - λ = Function(V) + u = Function(V) # Convert f to torch.Tensor - λ_P = pytorch_backend.to_ml_backend(λ) + u_P = pytorch_backend.to_ml_backend(u) # Forward pass - f_P = model(λ_P) + f_P = model(u_P) # Set control f = Function(V) @@ -203,7 +203,7 @@ def test_firedrake_loss_backward(V): # Check that gradients were propagated to model parameters # This test doesn't check the correctness of these gradients # -> This is checked in `test_taylor_torch_operator` - assert all([θi.grad is not None for θi in model.parameters()]) + assert all([pi.grad is not None for pi in model.parameters()]) # -- Check forward operator -- # f = pytorch_backend.from_ml_backend(f_P, V) @@ -221,11 +221,11 @@ def test_taylor_torch_operator(firedrake_operator, V): from firedrake_adjoint import ReducedFunctional, Control # Control value - ω = Function(V) + w = Function(V) # Get Firedrake operator and other operator arguments fd_op, args = firedrake_operator # Set reduced functional - Jhat = ReducedFunctional(fd_op(ω, *args), Control(ω)) + Jhat = ReducedFunctional(fd_op(w, *args), Control(w)) # Define the torch operator G = torch_operator(Jhat) # `gradcheck` is likely to fail if the inputs are not double precision (cf. 
https://pytorch.org/docs/stable/generated/torch.autograd.gradcheck.html) From aef4c11f23399ccc4c07d3f2cc166ff5ddd44987 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 29 Mar 2023 16:52:36 +0100 Subject: [PATCH 31/48] Rename backend function --- firedrake/ml_coupling/pytorch/pytorch_custom_operator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py b/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py index b0e1d79aef..bc3522e557 100644 --- a/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py +++ b/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py @@ -23,15 +23,15 @@ if backend: # PyTorch is installed - BackendFunction = backend.backend.autograd.Function + PytorchFunction = backend.backend.autograd.Function else: - class BackendFunction(): + class PytorchFunction(): """Dummy class that exceptions on instantiation.""" def __init__(self): raise ImportError("PyTorch is not installed and is required to use the FiredrakeTorchOperator.") -class FiredrakeTorchOperator(BackendFunction): +class FiredrakeTorchOperator(PytorchFunction): """ PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F. `FiredrakeTorchOperator` is a wrapper around `torch.autograd.Function` that executes forward and backward From a417a39fb57a159ce39406722176134d17bf4195 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Sat, 1 Apr 2023 11:14:31 +0100 Subject: [PATCH 32/48] Lift torch installation consideration to the import level --- firedrake/__init__.py | 1 - firedrake/ml/__init__.py | 2 ++ firedrake/{ml_coupling => ml}/backend_base.py | 2 +- firedrake/ml/pytorch/__init__.py | 7 ++++++ .../{ml_coupling => ml}/pytorch/backend.py | 24 ++++++------------- .../pytorch/pytorch_custom_operator.py | 14 +++-------- firedrake/ml_coupling/__init__.py | 2 -- tests/conftest.py | 8 +++++-- tests/regression/test_pytorch_coupling.py | 11 +++++---- 9 files changed, 33 insertions(+), 38 deletions(-) create mode 100644 firedrake/ml/__init__.py rename firedrake/{ml_coupling => ml}/backend_base.py (93%) create mode 100644 firedrake/ml/pytorch/__init__.py rename firedrake/{ml_coupling => ml}/pytorch/backend.py (83%) rename firedrake/{ml_coupling => ml}/pytorch/pytorch_custom_operator.py (88%) delete mode 100644 firedrake/ml_coupling/__init__.py diff --git a/firedrake/__init__.py b/firedrake/__init__.py index ddc318c04b..3fd3eb910c 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -103,7 +103,6 @@ from firedrake.ensemble import * from firedrake.randomfunctiongen import * from firedrake.progress_bar import ProgressBar # noqa: F401 -from firedrake.ml_coupling import * from firedrake.logging import * # Set default log level diff --git a/firedrake/ml/__init__.py b/firedrake/ml/__init__.py new file mode 100644 index 0000000000..7a28838ca5 --- /dev/null +++ b/firedrake/ml/__init__.py @@ -0,0 +1,2 @@ +from firedrake.ml.backend_base import load_backend # noqa: F401 +from firedrake.ml.pytorch import * # noqa: F401 diff --git a/firedrake/ml_coupling/backend_base.py b/firedrake/ml/backend_base.py similarity index 93% rename from firedrake/ml_coupling/backend_base.py rename to firedrake/ml/backend_base.py index d6acec3a72..4abba5e191 100644 --- a/firedrake/ml_coupling/backend_base.py +++ b/firedrake/ml/backend_base.py @@ -35,7 +35,7 @@ def function_space(self, x): def load_backend(backend_name='pytorch'): if backend_name == 'pytorch': - from firedrake.ml_coupling.pytorch.backend import 
PytorchBackend + from firedrake.ml.pytorch.backend import PytorchBackend return PytorchBackend() else: raise NotImplementedError("The backend: %s is not supported." % backend_name) diff --git a/firedrake/ml/pytorch/__init__.py b/firedrake/ml/pytorch/__init__.py new file mode 100644 index 0000000000..8b9915005f --- /dev/null +++ b/firedrake/ml/pytorch/__init__.py @@ -0,0 +1,7 @@ +try: + import torch + del torch +except ImportError: + raise ImportError("PyTorch is not installed and is required to use the FiredrakeTorchOperator.") + +from firedrake.ml.pytorch.pytorch_custom_operator import torch_operator # noqa: F401 diff --git a/firedrake/ml_coupling/pytorch/backend.py b/firedrake/ml/pytorch/backend.py similarity index 83% rename from firedrake/ml_coupling/pytorch/backend.py rename to firedrake/ml/pytorch/backend.py index 6754c39ad4..8e491a14b1 100644 --- a/firedrake/ml_coupling/pytorch/backend.py +++ b/firedrake/ml/pytorch/backend.py @@ -1,30 +1,20 @@ from firedrake.function import Function from firedrake.vector import Vector from firedrake.constant import Constant -from firedrake.ml_coupling.backend_base import AbstractMLBackend +from firedrake.ml.backend_base import AbstractMLBackend +import firedrake.utils as utils -class PytorchBackend(AbstractMLBackend): - def __init__(self): - try: - import torch - self._backend = torch - except ImportError: - self._backend = None +class PytorchBackend(AbstractMLBackend): - @property + @utils.cached_property def backend(self): - if self: - return self._backend - else: - raise ImportError("Error when trying to import PyTorch.") - - def __bool__(self): - return self._backend is not None + import torch + return torch def custom_operator(self, *args, **kwargs): - from firedrake.ml_coupling.pytorch.pytorch_custom_operator import FiredrakeTorchOperator + from firedrake.ml.pytorch.pytorch_custom_operator import FiredrakeTorchOperator return FiredrakeTorchOperator.apply(*args, **kwargs) def to_ml_backend(self, x, gather=False, batched=True, **kwargs): diff --git a/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py b/firedrake/ml/pytorch/pytorch_custom_operator.py similarity index 88% rename from firedrake/ml_coupling/pytorch/pytorch_custom_operator.py rename to firedrake/ml/pytorch/pytorch_custom_operator.py index bc3522e557..49628018ae 100644 --- a/firedrake/ml_coupling/pytorch/pytorch_custom_operator.py +++ b/firedrake/ml/pytorch/pytorch_custom_operator.py @@ -1,7 +1,8 @@ +import torch import collections from functools import partial -from firedrake.ml_coupling import load_backend +from firedrake.ml import load_backend from firedrake.function import Function from firedrake_citations import Citations @@ -21,17 +22,8 @@ backend = load_backend("pytorch") -if backend: - # PyTorch is installed - PytorchFunction = backend.backend.autograd.Function -else: - class PytorchFunction(): - """Dummy class that exceptions on instantiation.""" - def __init__(self): - raise ImportError("PyTorch is not installed and is required to use the FiredrakeTorchOperator.") - -class FiredrakeTorchOperator(PytorchFunction): +class FiredrakeTorchOperator(torch.autograd.Function): """ PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F. 
`FiredrakeTorchOperator` is a wrapper around `torch.autograd.Function` that executes forward and backward diff --git a/firedrake/ml_coupling/__init__.py b/firedrake/ml_coupling/__init__.py deleted file mode 100644 index b6e2cd7195..0000000000 --- a/firedrake/ml_coupling/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from firedrake.ml_coupling.backend_base import load_backend # noqa: F401 -from firedrake.ml_coupling.pytorch.pytorch_custom_operator import torch_operator # noqa: F401 diff --git a/tests/conftest.py b/tests/conftest.py index eb04a72a1f..65d47ab038 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,9 +21,13 @@ def pytest_configure(config): def pytest_collection_modifyitems(session, config, items): from firedrake.utils import complex_mode, SLATE_SUPPORTS_COMPLEX - from firedrake.ml_coupling import load_backend - ml_backend = load_backend("pytorch") + try: + import firedrake.ml as fd_ml + del fd_ml + ml_backend = True + except ImportError: + ml_backend = False for item in items: if complex_mode: diff --git a/tests/regression/test_pytorch_coupling.py b/tests/regression/test_pytorch_coupling.py index 1e1ed9f112..862b76b984 100644 --- a/tests/regression/test_pytorch_coupling.py +++ b/tests/regression/test_pytorch_coupling.py @@ -4,14 +4,14 @@ from pyadjoint.tape import get_working_tape, pause_annotation -pytorch_backend = load_backend("pytorch") - -if pytorch_backend: - # PyTorch is installed +try: + from firedrake.ml import load_backend, torch_operator import torch import torch.nn.functional as torch_func from torch.nn import Module, Flatten, Linear + pytorch_backend = load_backend("pytorch") + class EncoderDecoder(Module): """Build a simple toy model""" @@ -36,6 +36,9 @@ def forward(self, x): hidden = self.encode(x) # [batch_size, n] return self.decode(hidden) +except ImportError: + # PyTorch is not installed + pass @pytest.fixture(autouse=True) From 1ef391994831013c7b38dccb8193e21bb776064c Mon Sep 17 00:00:00 2001 From: Nacime Bouziani <48448063+nbouziani@users.noreply.github.com> Date: Sat, 1 Apr 2023 12:19:59 +0100 Subject: [PATCH 33/48] Add torch to vanilla docker container --- docker/Dockerfile.vanilla | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.vanilla b/docker/Dockerfile.vanilla index 1aa0624e71..bcefe10b34 100644 --- a/docker/Dockerfile.vanilla +++ b/docker/Dockerfile.vanilla @@ -10,4 +10,4 @@ WORKDIR /home/firedrake # Now install Firedrake. RUN curl -O https://raw.githubusercontent.com/firedrakeproject/firedrake/master/scripts/firedrake-install -RUN bash -c "PETSC_CONFIGURE_OPTIONS='--download-fftw=1' python3 firedrake-install --no-package-manager --disable-ssh --remove-build-files" +RUN bash -c "PETSC_CONFIGURE_OPTIONS='--download-fftw=1' python3 firedrake-install --no-package-manager --disable-ssh --remove-build-files --torch" From d51158a6385472c409c26f83f3526d9a38694383 Mon Sep 17 00:00:00 2001 From: Nacime Bouziani <48448063+nbouziani@users.noreply.github.com> Date: Mon, 3 Apr 2023 09:44:06 +0100 Subject: [PATCH 34/48] Update Dockerfile.vanilla --- docker/Dockerfile.vanilla | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.vanilla b/docker/Dockerfile.vanilla index bcefe10b34..14e6dde518 100644 --- a/docker/Dockerfile.vanilla +++ b/docker/Dockerfile.vanilla @@ -9,5 +9,5 @@ USER firedrake WORKDIR /home/firedrake # Now install Firedrake. 
-RUN curl -O https://raw.githubusercontent.com/firedrakeproject/firedrake/master/scripts/firedrake-install +RUN curl -O https://raw.githubusercontent.com/firedrakeproject/firedrake/pytorch_coupling/scripts/firedrake-install RUN bash -c "PETSC_CONFIGURE_OPTIONS='--download-fftw=1' python3 firedrake-install --no-package-manager --disable-ssh --remove-build-files --torch" From e5876c48e846f4832ba18484e1c6bd05d676d657 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Mon, 3 Apr 2023 12:35:23 +0100 Subject: [PATCH 35/48] Build (test) Firedrake docker image with torch --- .github/workflows/docs.yml | 21 +++++++++++++++++++++ docker/Dockerfile.firedrake | 2 +- docker/Dockerfile.vanilla | 4 ++-- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index b0f5bbe93d..546b56756e 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -10,8 +10,29 @@ on: jobs: + docker_test: + name: "Build Docker environment container" + runs-on: self-hosted + steps: + - name: Check out the repo + uses: actions/checkout@v2 + - name: Log in to Docker Hub + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKERHUB_USER }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + - name: Build and push firedrake-env + uses: docker/build-push-action@v2 + with: + push: true + file: docker/Dockerfile.firedrake + tags: firedrakeproject/firedrake-test:latest build: name: "Run doc build" + if: false # The type of runner that the job will run on runs-on: ubuntu-latest # The docker container to use. diff --git a/docker/Dockerfile.firedrake b/docker/Dockerfile.firedrake index 5acba73b66..c5fe3be728 100644 --- a/docker/Dockerfile.firedrake +++ b/docker/Dockerfile.firedrake @@ -9,4 +9,4 @@ USER firedrake WORKDIR /home/firedrake # Now install extra Firedrake components. -RUN bash -c "source firedrake/bin/activate; firedrake-update --documentation-dependencies --tinyasm --slepc --install thetis --install gusto --install icepack --install irksome --install femlium" +RUN bash -c "source firedrake/bin/activate; firedrake-update --documentation-dependencies --tinyasm --slepc --install thetis --install gusto --install icepack --install irksome --install femlium --torch" diff --git a/docker/Dockerfile.vanilla b/docker/Dockerfile.vanilla index 14e6dde518..1aa0624e71 100644 --- a/docker/Dockerfile.vanilla +++ b/docker/Dockerfile.vanilla @@ -9,5 +9,5 @@ USER firedrake WORKDIR /home/firedrake # Now install Firedrake. 
-RUN curl -O https://raw.githubusercontent.com/firedrakeproject/firedrake/pytorch_coupling/scripts/firedrake-install -RUN bash -c "PETSC_CONFIGURE_OPTIONS='--download-fftw=1' python3 firedrake-install --no-package-manager --disable-ssh --remove-build-files --torch" +RUN curl -O https://raw.githubusercontent.com/firedrakeproject/firedrake/master/scripts/firedrake-install +RUN bash -c "PETSC_CONFIGURE_OPTIONS='--download-fftw=1' python3 firedrake-install --no-package-manager --disable-ssh --remove-build-files" From 6bf1e2ffd4187ec3fdf75b5ded45bd8abdae4e3c Mon Sep 17 00:00:00 2001 From: nbouziani Date: Mon, 3 Apr 2023 12:43:42 +0100 Subject: [PATCH 36/48] Use test container for building docs --- .github/workflows/docs.yml | 44 +++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 546b56756e..88dd19c4ae 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -10,34 +10,13 @@ on: jobs: - docker_test: - name: "Build Docker environment container" - runs-on: self-hosted - steps: - - name: Check out the repo - uses: actions/checkout@v2 - - name: Log in to Docker Hub - uses: docker/login-action@v1 - with: - username: ${{ secrets.DOCKERHUB_USER }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Set up Docker Buildx - id: buildx - uses: docker/setup-buildx-action@v1 - - name: Build and push firedrake-env - uses: docker/build-push-action@v2 - with: - push: true - file: docker/Dockerfile.firedrake - tags: firedrakeproject/firedrake-test:latest build: name: "Run doc build" - if: false # The type of runner that the job will run on runs-on: ubuntu-latest # The docker container to use. container: - image: firedrakeproject/firedrake-vanilla:latest + image: firedrakeproject/firedrake-test:latest options: --user root # Steps represent a sequence of tasks that will be executed as # part of the jobs @@ -73,3 +52,24 @@ jobs: make html make latex make latexpdf + docker_test: + name: "Build Docker environment container" + needs: build + runs-on: self-hosted + steps: + - name: Check out the repo + uses: actions/checkout@v2 + - name: Log in to Docker Hub + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKERHUB_USER }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + - name: Build and push firedrake-env + uses: docker/build-push-action@v2 + with: + push: true + file: docker/Dockerfile.firedrake + tags: firedrakeproject/firedrake-test:latest From a465d9322ccab86e43022ab1245d44693a45dcf5 Mon Sep 17 00:00:00 2001 From: Nacime Bouziani <48448063+nbouziani@users.noreply.github.com> Date: Mon, 3 Apr 2023 14:58:03 +0100 Subject: [PATCH 37/48] Update docs.yml --- .github/workflows/docs.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 88dd19c4ae..9a093ba852 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -54,7 +54,6 @@ jobs: make latexpdf docker_test: name: "Build Docker environment container" - needs: build runs-on: self-hosted steps: - name: Check out the repo From d3448bd3e39e6602233ea17150b17e21be1d3413 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Tue, 16 May 2023 11:43:53 +0100 Subject: [PATCH 38/48] Revert container for docs --- .github/workflows/docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 
dacb4e99e5..36928f2bb0 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest # The docker container to use. container: - image: firedrakeproject/firedrake-test:latest + image: firedrakeproject/firedrake-vanilla:latest options: --user root # Steps represent a sequence of tasks that will be executed as # part of the jobs From b4d7f4851ebc3924327c7275775253192eb62708 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Tue, 16 May 2023 14:58:54 +0100 Subject: [PATCH 39/48] Fix docs --- firedrake/ml/pytorch/pytorch_custom_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/ml/pytorch/pytorch_custom_operator.py b/firedrake/ml/pytorch/pytorch_custom_operator.py index 49628018ae..5682269a49 100644 --- a/firedrake/ml/pytorch/pytorch_custom_operator.py +++ b/firedrake/ml/pytorch/pytorch_custom_operator.py @@ -31,7 +31,7 @@ class FiredrakeTorchOperator(torch.autograd.Function): Inputs: metadata: dictionary used to stash Firedrake objects. - *x_P: PyTorch tensors representing the inputs to the Firedrake operator F + x_P: PyTorch tensors representing the inputs to the Firedrake operator F Outputs: y_P: PyTorch tensor representing the output of the Firedrake operator F From 9304bf5cd7e233476607435d997e7680656daa7e Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 17 May 2023 11:55:02 +0100 Subject: [PATCH 40/48] Reference class in FiredrakeTorchOperator's docstring --- firedrake/ml/pytorch/pytorch_custom_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/ml/pytorch/pytorch_custom_operator.py b/firedrake/ml/pytorch/pytorch_custom_operator.py index 5682269a49..addf4b1e30 100644 --- a/firedrake/ml/pytorch/pytorch_custom_operator.py +++ b/firedrake/ml/pytorch/pytorch_custom_operator.py @@ -26,7 +26,7 @@ class FiredrakeTorchOperator(torch.autograd.Function): """ PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F. - `FiredrakeTorchOperator` is a wrapper around `torch.autograd.Function` that executes forward and backward + `FiredrakeTorchOperator` is a wrapper around :class:`torch.autograd.Function` that executes forward and backward passes by directly calling the reduced functional F. 
Inputs: From 9fdcd1a988b6892b593b3644642f01d998592719 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Wed, 17 May 2023 14:54:23 +0100 Subject: [PATCH 41/48] Add torch doc link --- docs/source/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index 1d8cb7e0da..a68c504800 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -388,6 +388,7 @@ 'pyadjoint': ('https://www.dolfin-adjoint.org/en/latest/', None), 'numpy': ('https://numpy.org/doc/stable/', None), 'loopy': ('https://documen.tician.de/loopy/', None), + 'torch': ('https://pytorch.org/docs/stable/', None), } # -- Options for sphinxcontrib.bibtex ------------------------------------ From 7791fb5b45900f0aaf4d9ad9561ef1f2bd58a155 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Thu, 18 May 2023 11:26:20 +0100 Subject: [PATCH 42/48] Remove backend abstraction --- firedrake/ml/__init__.py | 3 +- firedrake/ml/backend_base.py | 41 ----- firedrake/ml/pytorch.py | 160 ++++++++++++++++++ firedrake/ml/pytorch/__init__.py | 7 - firedrake/ml/pytorch/backend.py | 70 -------- .../ml/pytorch/pytorch_custom_operator.py | 92 ---------- tests/regression/test_pytorch_coupling.py | 16 +- 7 files changed, 168 insertions(+), 221 deletions(-) delete mode 100644 firedrake/ml/backend_base.py create mode 100644 firedrake/ml/pytorch.py delete mode 100644 firedrake/ml/pytorch/__init__.py delete mode 100644 firedrake/ml/pytorch/backend.py delete mode 100644 firedrake/ml/pytorch/pytorch_custom_operator.py diff --git a/firedrake/ml/__init__.py b/firedrake/ml/__init__.py index 7a28838ca5..15ae7dfae2 100644 --- a/firedrake/ml/__init__.py +++ b/firedrake/ml/__init__.py @@ -1,2 +1 @@ -from firedrake.ml.backend_base import load_backend # noqa: F401 -from firedrake.ml.pytorch import * # noqa: F401 +from firedrake.ml.pytorch import FiredrakeTorchOperator, torch_operator, from_ml_backend, to_ml_backend # noqa: F401 diff --git a/firedrake/ml/backend_base.py b/firedrake/ml/backend_base.py deleted file mode 100644 index 4abba5e191..0000000000 --- a/firedrake/ml/backend_base.py +++ /dev/null @@ -1,41 +0,0 @@ -from firedrake.function import Function -from firedrake.vector import Vector - - -class AbstractMLBackend(object): - - def backend(self): - raise NotImplementedError - - def to_ml_backend(self, x): - r"""Convert from Firedrake to ML backend. - - :arg x: Firedrake object - """ - raise NotImplementedError - - def from_ml_backend(self, x, V): - r"""Convert from ML backend to Firedrake. - - :arg x: ML backend object - """ - raise NotImplementedError - - def function_space(self, x): - """Get function space out of x""" - if isinstance(x, Function): - return x.function_space() - elif isinstance(x, Vector): - return self.function_space(x.function) - elif isinstance(x, float): - return None - else: - raise ValueError("Cannot infer the function space of %s" % x) - - -def load_backend(backend_name='pytorch'): - if backend_name == 'pytorch': - from firedrake.ml.pytorch.backend import PytorchBackend - return PytorchBackend() - else: - raise NotImplementedError("The backend: %s is not supported." 
% backend_name) diff --git a/firedrake/ml/pytorch.py b/firedrake/ml/pytorch.py new file mode 100644 index 0000000000..304c732a17 --- /dev/null +++ b/firedrake/ml/pytorch.py @@ -0,0 +1,160 @@ +try: + import torch +except ImportError: + raise ImportError("PyTorch is not installed and is required to use the FiredrakeTorchOperator.") + +import collections +from functools import partial + +from firedrake.function import Function +from firedrake.vector import Vector +from firedrake.constant import Constant +from firedrake_citations import Citations + +from pyadjoint.reduced_functional import ReducedFunctional + + +Citations().add("Bouziani2023", """ +@inproceedings{Bouziani2023, + title = {Physics-driven machine learning models coupling {PyTorch} and {Firedrake}}, + author = {Bouziani, Nacime and Ham, David A.}, + booktitle = {{ICLR} 2023 {Workshop} on {Physics} for {Machine} {Learning}}, + year = {2023}, + doi = {10.48550/arXiv.2303.06871} +} +""") + + +class FiredrakeTorchOperator(torch.autograd.Function): + """ + PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F. + `FiredrakeTorchOperator` is a wrapper around :class:`torch.autograd.Function` that executes forward and backward + passes by directly calling the reduced functional F. + + Inputs: + metadata: dictionary used to stash Firedrake objects. + x_P: PyTorch tensors representing the inputs to the Firedrake operator F + + Outputs: + y_P: PyTorch tensor representing the output of the Firedrake operator F + """ + + def __init__(self): + super(FiredrakeTorchOperator, self).__init__() + + # This method is wrapped by something cancelling annotation (probably 'with torch.no_grad()') + @staticmethod + def forward(ctx, metadata, *x_P): + """Forward pass of the PyTorch custom operator.""" + F = metadata['F'] + V = metadata['V_controls'] + # Convert PyTorch input (i.e. controls) to Firedrake + x_F = [from_ml_backend(xi, Vi) for xi, Vi in zip(x_P, V)] + # Forward operator: delegated to pyadjoint.ReducedFunctional which recomputes the blocks on the tape + y_F = F(x_F) + # Stash metadata to the PyTorch context + ctx.metadata.update(metadata) + # Convert Firedrake output to PyTorch + y_P = to_ml_backend(y_F) + return y_P.detach() + + @staticmethod + def backward(ctx, grad_output): + """Backward pass of the PyTorch custom operator.""" + F = ctx.metadata['F'] + V = ctx.metadata['V_output'] + # Convert PyTorch gradient to Firedrake + adj_input = from_ml_backend(grad_output, V) + if isinstance(adj_input, Function): + adj_input = adj_input.vector() + + # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional + adj_output = F.derivative(adj_input=adj_input) + + # Tuplify adjoint output + adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output + + # None is for metadata arg in `forward` + return None, *[to_ml_backend(di) for di in adj_output] + + +def torch_operator(F): + """Operator that converts a pyadjoint.ReducedFunctional into a firedrake.FiredrakeTorchOperator + whose inputs and outputs are PyTorch tensors. 
+ """ + Citations().register("Bouziani2023") + + if not isinstance(F, ReducedFunctional): + raise ValueError("F must be a ReducedFunctional") + + V_output = extract_function_space(F.functional) + V_controls = [c.control.function_space() for c in F.controls] + metadata = {'F': F, 'V_controls': V_controls, 'V_output': V_output} + F_P = partial(FiredrakeTorchOperator.apply, metadata) + return F_P + + +def extract_function_space(x): + """Get function space out of x""" + if isinstance(x, Function): + return x.function_space() + elif isinstance(x, Vector): + return extract_function_space(x.function) + elif isinstance(x, float): + return None + else: + raise ValueError("Cannot infer the function space of %s" % x) + + +def to_ml_backend(x, gather=False, batched=True, **kwargs): + r"""Convert a Firedrake object `x` into a PyTorch tensor. + + :arg x: Firedrake object (Function, Vector, Constant) + :kwarg gather: if True, gather data from all processes + :kwarg batched: if True, add a batch dimension to the tensor + :kwarg kwargs: additional arguments to be passed to torch.Tensor constructor + - device: device on which the tensor is allocated (default: "cpu") + - dtype: the desired data type of returned tensor (default: type of x.dat.data) + - requires_grad: if the tensor should be annotated (default: False) + """ + if isinstance(x, (Function, Vector)): + if gather: + # Gather data from all processes + x_P = torch.tensor(x.vector().gather(), **kwargs) + else: + # Use local data + x_P = torch.tensor(x.vector().get_local(), **kwargs) + if batched: + # Default behaviour: add batch dimension after converting to PyTorch + return x_P[None, :] + return x_P + elif isinstance(x, Constant): + return torch.tensor(x.values(), **kwargs) + elif isinstance(x, (float, int)): + if isinstance(x, float): + # Set double-precision + kwargs['dtype'] = torch.double + return torch.tensor(x, **kwargs) + else: + raise ValueError("Cannot convert %s to a torch tensor" % str(type(x))) + + +def from_ml_backend(x, V=None): + r"""Convert a PyTorch tensor `x` into a Firedrake object. 
+ + :arg x: PyTorch tensor (torch.Tensor) + :kwarg V: function space of the corresponding Function or None when `x` is to be mapped to a Constant + """ + if x.device.type != "cpu": + raise NotImplementedError("Firedrake does not support GPU tensors") + + if V is None: + val = x.detach().numpy() + if val.shape == (1,): + val = val[0] + return Constant(val) + else: + x = x.detach().numpy() + x_F = Function(V, dtype=x.dtype) + x_F.vector().set_local(x) + return x_F diff --git a/firedrake/ml/pytorch/__init__.py b/firedrake/ml/pytorch/__init__.py deleted file mode 100644 index 8b9915005f..0000000000 --- a/firedrake/ml/pytorch/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -try: - import torch - del torch -except ImportError: - raise ImportError("PyTorch is not installed and is required to use the FiredrakeTorchOperator.") - -from firedrake.ml.pytorch.pytorch_custom_operator import torch_operator # noqa: F401 diff --git a/firedrake/ml/pytorch/backend.py b/firedrake/ml/pytorch/backend.py deleted file mode 100644 index 8e491a14b1..0000000000 --- a/firedrake/ml/pytorch/backend.py +++ /dev/null @@ -1,70 +0,0 @@ -from firedrake.function import Function -from firedrake.vector import Vector -from firedrake.constant import Constant -from firedrake.ml.backend_base import AbstractMLBackend - -import firedrake.utils as utils - - -class PytorchBackend(AbstractMLBackend): - - @utils.cached_property - def backend(self): - import torch - return torch - - def custom_operator(self, *args, **kwargs): - from firedrake.ml.pytorch.pytorch_custom_operator import FiredrakeTorchOperator - return FiredrakeTorchOperator.apply(*args, **kwargs) - - def to_ml_backend(self, x, gather=False, batched=True, **kwargs): - r"""Convert a Firedrake object `x` into a PyTorch tensor. - - :arg x: Firedrake object (Function, Vector, Constant) - :kwarg gather: if True, gather data from all processes - :kwarg batched: if True, add a batch dimension to the tensor - :kwarg kwargs: additional arguments to be passed to torch.Tensor constructor - - device: device on which the tensor is allocated (default: "cpu") - - dtype: the desired data type of returned tensor (default: type of x.dat.data) - - requires_grad: if the tensor should be annotated (default: False) - """ - if isinstance(x, (Function, Vector)): - if gather: - # Gather data from all processes - x_P = self.backend.tensor(x.vector().gather(), **kwargs) - else: - # Use local data - x_P = self.backend.tensor(x.vector().get_local(), **kwargs) - if batched: - # Default behaviour: add batch dimension after converting to PyTorch - return x_P[None, :] - return x_P - elif isinstance(x, Constant): - return self.backend.tensor(x.values(), **kwargs) - elif isinstance(x, (float, int)): - if isinstance(x, float): - # Set double-precision - kwargs['dtype'] = self.backend.double - return self.backend.tensor(x, **kwargs) - else: - raise ValueError("Cannot convert %s to a torch tensor" % str(type(x))) - - def from_ml_backend(self, x, V=None): - r"""Convert a PyTorch tensor `x` into a Firedrake object. 
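As a small illustration of the scalar code paths in the new module-level helpers above (a sketch; the helper names are those defined in firedrake/ml/pytorch.py at this point in the series, before a later rename): a Python float is promoted to a double-precision tensor, and a one-element tensor converted without a function space comes back as a Constant.

    import torch
    from firedrake import Constant
    from firedrake.ml.pytorch import to_ml_backend, from_ml_backend

    t = to_ml_backend(3.5)                    # float input -> tensor with dtype=torch.double
    assert t.dtype == torch.double

    c = from_ml_backend(torch.tensor([2.0]))  # no V given -> mapped to a firedrake Constant
    assert isinstance(c, Constant)
    assert c.values()[0] == 2.0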
- - :arg x: PyTorch tensor (torch.Tensor) - :kwarg V: function space of the corresponding Function or None when `x` is to be mapped to a Constant - """ - if x.device.type != "cpu": - raise NotImplementedError("Firedrake does not support GPU tensors") - - if V is None: - val = x.detach().numpy() - if val.shape == (1,): - val = val[0] - return Constant(val) - else: - x = x.detach().numpy() - x_F = Function(V, dtype=x.dtype) - x_F.vector().set_local(x) - return x_F diff --git a/firedrake/ml/pytorch/pytorch_custom_operator.py b/firedrake/ml/pytorch/pytorch_custom_operator.py deleted file mode 100644 index addf4b1e30..0000000000 --- a/firedrake/ml/pytorch/pytorch_custom_operator.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import collections -from functools import partial - -from firedrake.ml import load_backend -from firedrake.function import Function -from firedrake_citations import Citations - -from pyadjoint.reduced_functional import ReducedFunctional - - -Citations().add("Bouziani2023", """ -@inproceedings{Bouziani2023, - title = {Physics-driven machine learning models coupling {PyTorch} and {Firedrake}}, - author = {Bouziani, Nacime and Ham, David A.}, - booktitle = {{ICLR} 2023 {Workshop} on {Physics} for {Machine} {Learning}}, - year = {2023}, - doi = {10.48550/arXiv.2303.06871} -} -""") - - -backend = load_backend("pytorch") - - -class FiredrakeTorchOperator(torch.autograd.Function): - """ - PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F. - `FiredrakeTorchOperator` is a wrapper around :class:`torch.autograd.Function` that executes forward and backward - passes by directly calling the reduced functional F. - - Inputs: - metadata: dictionary used to stash Firedrake objects. - x_P: PyTorch tensors representing the inputs to the Firedrake operator F - - Outputs: - y_P: PyTorch tensor representing the output of the Firedrake operator F - """ - - def __init__(self): - super(FiredrakeTorchOperator, self).__init__() - - # This method is wrapped by something cancelling annotation (probably 'with torch.no_grad()') - @staticmethod - def forward(ctx, metadata, *x_P): - """Forward pass of the PyTorch custom operator.""" - F = metadata['F'] - V = metadata['V_controls'] - # Convert PyTorch input (i.e. controls) to Firedrake - x_F = [backend.from_ml_backend(xi, Vi) for xi, Vi in zip(x_P, V)] - # Forward operator: delegated to pyadjoint.ReducedFunctional which recomputes the blocks on the tape - y_F = F(x_F) - # Stash metadata to the PyTorch context - ctx.metadata.update(metadata) - # Convert Firedrake output to PyTorch - y_P = backend.to_ml_backend(y_F) - return y_P.detach() - - @staticmethod - def backward(ctx, grad_output): - """Backward pass of the PyTorch custom operator.""" - F = ctx.metadata['F'] - V = ctx.metadata['V_output'] - # Convert PyTorch gradient to Firedrake - adj_input = backend.from_ml_backend(grad_output, V) - if isinstance(adj_input, Function): - adj_input = adj_input.vector() - - # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional - adj_output = F.derivative(adj_input=adj_input) - - # Tuplify adjoint output - adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output - - # None is for metadata arg in `forward` - return None, *[backend.to_ml_backend(di) for di in adj_output] - - -def torch_operator(F): - """Operator that converts a pyadjoint.ReducedFunctional into a firedrake.FiredrakeTorchOperator - whose inputs and outputs are PyTorch tensors. 
- """ - Citations().register("Bouziani2023") - - if not isinstance(F, ReducedFunctional): - raise ValueError("F must be a ReducedFunctional") - - V_output = backend.function_space(F.functional) - V_controls = [c.control.function_space() for c in F.controls] - metadata = {'F': F, 'V_controls': V_controls, 'V_output': V_output} - F_P = partial(backend.custom_operator, metadata) - return F_P diff --git a/tests/regression/test_pytorch_coupling.py b/tests/regression/test_pytorch_coupling.py index 862b76b984..d4761b2c40 100644 --- a/tests/regression/test_pytorch_coupling.py +++ b/tests/regression/test_pytorch_coupling.py @@ -5,13 +5,11 @@ try: - from firedrake.ml import load_backend, torch_operator + from firedrake.ml.pytorch import * import torch import torch.nn.functional as torch_func from torch.nn import Module, Flatten, Linear - pytorch_backend = load_backend("pytorch") - class EncoderDecoder(Module): """Build a simple toy model""" @@ -124,7 +122,7 @@ def test_pytorch_loss_backward(V, f_exact): assert all([pi.grad is None for pi in model.parameters()]) # Convert f_exact to torch.Tensor - f_P = pytorch_backend.to_ml_backend(f_exact) + f_P = to_ml_backend(f_exact) # Forward pass u_P = model(f_P) @@ -154,9 +152,9 @@ def test_pytorch_loss_backward(V, f_exact): assert all([pi.grad is not None for pi in model.parameters()]) # -- Check forward operator -- # - u = pytorch_backend.from_ml_backend(u_P, V) + u = from_ml_backend(u_P, V) residual = poisson_residual(u, f_exact, V) - residual_P_exact = pytorch_backend.to_ml_backend(residual) + residual_P_exact = to_ml_backend(residual) assert (residual_P - residual_P_exact).detach().norm() < 1e-10 @@ -181,7 +179,7 @@ def test_firedrake_loss_backward(V): u = Function(V) # Convert f to torch.Tensor - u_P = pytorch_backend.to_ml_backend(u) + u_P = to_ml_backend(u) # Forward pass f_P = model(u_P) @@ -209,9 +207,9 @@ def test_firedrake_loss_backward(V): assert all([pi.grad is not None for pi in model.parameters()]) # -- Check forward operator -- # - f = pytorch_backend.from_ml_backend(f_P, V) + f = from_ml_backend(f_P, V) loss = solve_poisson(f, V) - loss_P_exact = pytorch_backend.to_ml_backend(loss) + loss_P_exact = to_ml_backend(loss) assert (loss_P - loss_P_exact).detach().norm() < 1e-10 From ef5b67b680a6e324ec99788cc90d4b9dcc9a6741 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Thu, 18 May 2023 13:01:06 +0100 Subject: [PATCH 43/48] Address comments --- firedrake/ml/__init__.py | 1 - firedrake/ml/pytorch.py | 120 +++++++++++++++------- tests/regression/test_pytorch_coupling.py | 12 +-- 3 files changed, 91 insertions(+), 42 deletions(-) diff --git a/firedrake/ml/__init__.py b/firedrake/ml/__init__.py index 15ae7dfae2..e69de29bb2 100644 --- a/firedrake/ml/__init__.py +++ b/firedrake/ml/__init__.py @@ -1 +0,0 @@ -from firedrake.ml.pytorch import FiredrakeTorchOperator, torch_operator, from_ml_backend, to_ml_backend # noqa: F401 diff --git a/firedrake/ml/pytorch.py b/firedrake/ml/pytorch.py index 304c732a17..012404aae0 100644 --- a/firedrake/ml/pytorch.py +++ b/firedrake/ml/pytorch.py @@ -25,18 +25,25 @@ """) -class FiredrakeTorchOperator(torch.autograd.Function): - """ - PyTorch custom operator representing a set of Firedrake operations expressed as a ReducedFunctional F. - `FiredrakeTorchOperator` is a wrapper around :class:`torch.autograd.Function` that executes forward and backward - passes by directly calling the reduced functional F. 
+__all__ = ['FiredrakeTorchOperator', 'torch_operator', 'to_torch', 'from_torch'] - Inputs: - metadata: dictionary used to stash Firedrake objects. - x_P: PyTorch tensors representing the inputs to the Firedrake operator F - Outputs: - y_P: PyTorch tensor representing the output of the Firedrake operator F +class FiredrakeTorchOperator(torch.autograd.Function): + """PyTorch custom operator representing a set of Firedrake operations expressed as a reduced functional `F`. + `FiredrakeTorchOperator` is a wrapper around :class:`torch.autograd.Function` that executes forward and backward + passes by directly calling the reduced functional `F`. + + Parameters + ---------- + metadata : dict + Dictionary used to stash Firedrake objects. + x_P : list of torch.Tensor + PyTorch tensors representing the inputs to the Firedrake operator `F`. + + Returns + ------- + torch.Tensor + PyTorch tensor representing the output of the Firedrake operator `F`. """ def __init__(self): @@ -45,26 +52,28 @@ def __init__(self): # This method is wrapped by something cancelling annotation (probably 'with torch.no_grad()') @staticmethod def forward(ctx, metadata, *x_P): - """Forward pass of the PyTorch custom operator.""" + """Forward pass of the PyTorch custom operator. + """ F = metadata['F'] V = metadata['V_controls'] # Convert PyTorch input (i.e. controls) to Firedrake - x_F = [from_ml_backend(xi, Vi) for xi, Vi in zip(x_P, V)] + x_F = [from_torch(xi, Vi) for xi, Vi in zip(x_P, V)] # Forward operator: delegated to pyadjoint.ReducedFunctional which recomputes the blocks on the tape y_F = F(x_F) # Stash metadata to the PyTorch context ctx.metadata.update(metadata) # Convert Firedrake output to PyTorch - y_P = to_ml_backend(y_F) + y_P = to_torch(y_F) return y_P.detach() @staticmethod def backward(ctx, grad_output): - """Backward pass of the PyTorch custom operator.""" + """Backward pass of the PyTorch custom operator. + """ F = ctx.metadata['F'] V = ctx.metadata['V_output'] # Convert PyTorch gradient to Firedrake - adj_input = from_ml_backend(grad_output, V) + adj_input = from_torch(grad_output, V) if isinstance(adj_input, Function): adj_input = adj_input.vector() @@ -75,47 +84,79 @@ def backward(ctx, grad_output): adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output # None is for metadata arg in `forward` - return None, *[to_ml_backend(di) for di in adj_output] + return None, *[to_torch(di) for di in adj_output] def torch_operator(F): - """Operator that converts a pyadjoint.ReducedFunctional into a firedrake.FiredrakeTorchOperator + """Operator that converts a :class:`pyadjoint.ReducedFunctional` into a :class:`~FiredrakeTorchOperator` whose inputs and outputs are PyTorch tensors. + + Parameters + ---------- + F : pyadjoint.ReducedFunctional + The reduced functional to wrap. + + Returns + ------- + firedrake.ml.pytorch.FiredrakeTorchOperator + A PyTorch custom operator that wraps the reduced functional `F`. 
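As a usage sketch (hypothetical Poisson-type setup; only torch_operator, to_torch and the pyadjoint classes are taken from this patch), a reduced functional might be built under pyadjoint annotation and then handed to torch_operator:

    import torch
    from firedrake import *
    from firedrake.adjoint import *
    from firedrake.ml.pytorch import torch_operator, to_torch

    continue_annotation()                      # record Firedrake operations on the pyadjoint tape

    mesh = UnitSquareMesh(10, 10)
    V = FunctionSpace(mesh, "CG", 1)
    f = Function(V)                            # control
    u = Function(V)
    v = TestFunction(V)
    F_res = (inner(grad(u), grad(v)) + inner(u, v) - inner(f, v)) * dx
    solve(F_res == 0, u)                       # recorded on the tape
    J = assemble(inner(u, u) * dx)             # scalar functional of the solution
    Jhat = ReducedFunctional(J, Control(f))

    pause_annotation()

    G = torch_operator(Jhat)                   # PyTorch-callable wrapper around Jhat
    f_t = to_torch(f).requires_grad_(True)
    loss = G(f_t)                              # forward: recompute the tape for this control value
    loss.backward()                            # backward: adjoint of Jhat; f_t.grad now holds dJ/df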
""" Citations().register("Bouziani2023") if not isinstance(F, ReducedFunctional): raise ValueError("F must be a ReducedFunctional") - V_output = extract_function_space(F.functional) + V_output = _extract_function_space(F.functional) V_controls = [c.control.function_space() for c in F.controls] metadata = {'F': F, 'V_controls': V_controls, 'V_output': V_output} F_P = partial(FiredrakeTorchOperator.apply, metadata) return F_P -def extract_function_space(x): - """Get function space out of x""" +def _extract_function_space(x): + """Extract the function space from a Firedrake object `x`. + + Parameters + ---------- + x : float, firedrake.Function or firedrake.Vector + Firedrake object from which to extract the function space. + + Returns + ------- + firedrake.functionspaceimpl.WithGeometry or None + Extracted function space. + """ if isinstance(x, Function): return x.function_space() elif isinstance(x, Vector): - return extract_function_space(x.function) + return _extract_function_space(x.function) elif isinstance(x, float): return None else: raise ValueError("Cannot infer the function space of %s" % x) -def to_ml_backend(x, gather=False, batched=True, **kwargs): - r"""Convert a Firedrake object `x` into a PyTorch tensor. - - :arg x: Firedrake object (Function, Vector, Constant) - :kwarg gather: if True, gather data from all processes - :kwarg batched: if True, add a batch dimension to the tensor - :kwarg kwargs: additional arguments to be passed to torch.Tensor constructor - - device: device on which the tensor is allocated (default: "cpu") - - dtype: the desired data type of returned tensor (default: type of x.dat.data) - - requires_grad: if the tensor should be annotated (default: False) +def to_torch(x, gather=False, batched=True, **kwargs): + """Convert a Firedrake object `x` into a PyTorch tensor. + + Parameters + ---------- + x : firedrake.Function, firedrake.Vector or firedrake.Constant + Firedrake object to convert. + gather : bool, optional + If True, gather data from all processes + batched : bool, optional + If True, add a batch dimension to the tensor + kwargs : dict, optional + Additional arguments to be passed to the :class:`torch.Tensor` constructor such as: + - device: device on which the tensor is allocated (default: "cpu") + - dtype: the desired data type of returned tensor (default: type of `x.dat.data`) + - requires_grad: if the tensor should be annotated (default: False) + + Returns + ------- + torch.Tensor + PyTorch tensor representing the Firedrake object `x`. """ if isinstance(x, (Function, Vector)): if gather: @@ -139,11 +180,20 @@ def to_ml_backend(x, gather=False, batched=True, **kwargs): raise ValueError("Cannot convert %s to a torch tensor" % str(type(x))) -def from_ml_backend(x, V=None): - r"""Convert a PyTorch tensor `x` into a Firedrake object. +def from_torch(x, V=None): + """Convert a PyTorch tensor `x` into a Firedrake object. + + Parameters + ---------- + x : torch.Tensor + PyTorch tensor to convert. + V : firedrake.functionspaceimpl.WithGeometry or None, optional + Function space of the corresponding :class:`firedrake.Function` or None when `x` is to be mapped to a :class:`firedrake.Constant`. - :arg x: PyTorch tensor (torch.Tensor) - :kwarg V: function space of the corresponding Function or None when `x` is to be mapped to a Constant + Returns + ------- + firedrake.Function or firedrake.Constant + Firedrake object representing the PyTorch tensor `x`. 
""" if x.device.type != "cpu": raise NotImplementedError("Firedrake does not support GPU tensors") diff --git a/tests/regression/test_pytorch_coupling.py b/tests/regression/test_pytorch_coupling.py index d4761b2c40..0f5e6ee2d4 100644 --- a/tests/regression/test_pytorch_coupling.py +++ b/tests/regression/test_pytorch_coupling.py @@ -122,7 +122,7 @@ def test_pytorch_loss_backward(V, f_exact): assert all([pi.grad is None for pi in model.parameters()]) # Convert f_exact to torch.Tensor - f_P = to_ml_backend(f_exact) + f_P = to_torch(f_exact) # Forward pass u_P = model(f_P) @@ -152,9 +152,9 @@ def test_pytorch_loss_backward(V, f_exact): assert all([pi.grad is not None for pi in model.parameters()]) # -- Check forward operator -- # - u = from_ml_backend(u_P, V) + u = from_torch(u_P, V) residual = poisson_residual(u, f_exact, V) - residual_P_exact = to_ml_backend(residual) + residual_P_exact = to_torch(residual) assert (residual_P - residual_P_exact).detach().norm() < 1e-10 @@ -179,7 +179,7 @@ def test_firedrake_loss_backward(V): u = Function(V) # Convert f to torch.Tensor - u_P = to_ml_backend(u) + u_P = to_torch(u) # Forward pass f_P = model(u_P) @@ -207,9 +207,9 @@ def test_firedrake_loss_backward(V): assert all([pi.grad is not None for pi in model.parameters()]) # -- Check forward operator -- # - f = from_ml_backend(f_P, V) + f = from_torch(f_P, V) loss = solve_poisson(f, V) - loss_P_exact = to_ml_backend(loss) + loss_P_exact = to_torch(loss) assert (loss_P - loss_P_exact).detach().norm() < 1e-10 From 78bdd898626a1f8347b387bb4b52415337ce504f Mon Sep 17 00:00:00 2001 From: nbouziani Date: Thu, 18 May 2023 13:20:43 +0100 Subject: [PATCH 44/48] Fix class reference in from_torch docstring --- firedrake/ml/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/ml/pytorch.py b/firedrake/ml/pytorch.py index 012404aae0..1aa456ff29 100644 --- a/firedrake/ml/pytorch.py +++ b/firedrake/ml/pytorch.py @@ -188,7 +188,7 @@ def from_torch(x, V=None): x : torch.Tensor PyTorch tensor to convert. V : firedrake.functionspaceimpl.WithGeometry or None, optional - Function space of the corresponding :class:`firedrake.Function` or None when `x` is to be mapped to a :class:`firedrake.Constant`. + Function space of the corresponding :class:`~firedrake.Function` or None when `x` is to be mapped to a :class:`~firedrake.Constant`. Returns ------- From ec74361e8b4c6ee28f2924e3a4dfcbc259f04bf3 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Thu, 18 May 2023 13:44:52 +0100 Subject: [PATCH 45/48] Remove optional in docstrings --- firedrake/ml/pytorch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/firedrake/ml/pytorch.py b/firedrake/ml/pytorch.py index 1aa456ff29..88939f6009 100644 --- a/firedrake/ml/pytorch.py +++ b/firedrake/ml/pytorch.py @@ -143,11 +143,11 @@ def to_torch(x, gather=False, batched=True, **kwargs): ---------- x : firedrake.Function, firedrake.Vector or firedrake.Constant Firedrake object to convert. - gather : bool, optional + gather : bool If True, gather data from all processes - batched : bool, optional + batched : bool If True, add a batch dimension to the tensor - kwargs : dict, optional + kwargs : dict Additional arguments to be passed to the :class:`torch.Tensor` constructor such as: - device: device on which the tensor is allocated (default: "cpu") - dtype: the desired data type of returned tensor (default: type of `x.dat.data`) @@ -187,7 +187,7 @@ def from_torch(x, V=None): ---------- x : torch.Tensor PyTorch tensor to convert. 
- V : firedrake.functionspaceimpl.WithGeometry or None, optional + V : firedrake.functionspaceimpl.WithGeometry or None Function space of the corresponding :class:`~firedrake.Function` or None when `x` is to be mapped to a :class:`~firedrake.Constant`. Returns From c0a7dcd13ce6fdac5cafdca2f22839a951f42922 Mon Sep 17 00:00:00 2001 From: "David A. Ham" Date: Thu, 18 May 2023 13:56:11 +0100 Subject: [PATCH 46/48] Update firedrake/ml/pytorch.py Missing blank line. --- firedrake/ml/pytorch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/firedrake/ml/pytorch.py b/firedrake/ml/pytorch.py index 88939f6009..e91a87b8c8 100644 --- a/firedrake/ml/pytorch.py +++ b/firedrake/ml/pytorch.py @@ -30,6 +30,7 @@ class FiredrakeTorchOperator(torch.autograd.Function): """PyTorch custom operator representing a set of Firedrake operations expressed as a reduced functional `F`. + `FiredrakeTorchOperator` is a wrapper around :class:`torch.autograd.Function` that executes forward and backward passes by directly calling the reduced functional `F`. From 8d04288703ed5b5a358a2e9c783d3985e2caeb56 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Thu, 18 May 2023 14:03:19 +0100 Subject: [PATCH 47/48] Address comments (part 2) --- firedrake/ml/pytorch.py | 5 +++-- tests/conftest.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/firedrake/ml/pytorch.py b/firedrake/ml/pytorch.py index e91a87b8c8..8b3a05abfd 100644 --- a/firedrake/ml/pytorch.py +++ b/firedrake/ml/pytorch.py @@ -89,8 +89,9 @@ def backward(ctx, grad_output): def torch_operator(F): - """Operator that converts a :class:`pyadjoint.ReducedFunctional` into a :class:`~FiredrakeTorchOperator` - whose inputs and outputs are PyTorch tensors. + """Cast a Firedrake reduced functional to a PyTorch operator. + + The resulting :class:`~FiredrakeTorchOperator` will take PyTorch tensors as inputs and return PyTorch tensors as outputs. Parameters ---------- diff --git a/tests/conftest.py b/tests/conftest.py index 65d47ab038..0c96cb5898 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ def pytest_collection_modifyitems(session, config, items): from firedrake.utils import complex_mode, SLATE_SUPPORTS_COMPLEX try: - import firedrake.ml as fd_ml + import firedrake.ml.pytorch as fd_ml del fd_ml ml_backend = True except ImportError: From ce8a99b38c11199bcebcc8e1eaa91744f539653d Mon Sep 17 00:00:00 2001 From: nbouziani Date: Thu, 18 May 2023 15:08:55 +0100 Subject: [PATCH 48/48] Fix docs --- firedrake/ml/pytorch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/firedrake/ml/pytorch.py b/firedrake/ml/pytorch.py index 8b3a05abfd..b510760f7c 100644 --- a/firedrake/ml/pytorch.py +++ b/firedrake/ml/pytorch.py @@ -120,7 +120,7 @@ def _extract_function_space(x): Parameters ---------- - x : float, firedrake.Function or firedrake.Vector + x : float, firedrake.function.Function or firedrake.vector.Vector Firedrake object from which to extract the function space. Returns @@ -143,7 +143,7 @@ def to_torch(x, gather=False, batched=True, **kwargs): Parameters ---------- - x : firedrake.Function, firedrake.Vector or firedrake.Constant + x : firedrake.function.Function, firedrake.vector.Vector or firedrake.constant.Constant Firedrake object to convert. gather : bool If True, gather data from all processes @@ -190,11 +190,11 @@ def from_torch(x, V=None): x : torch.Tensor PyTorch tensor to convert. 
V : firedrake.functionspaceimpl.WithGeometry or None - Function space of the corresponding :class:`~firedrake.Function` or None when `x` is to be mapped to a :class:`~firedrake.Constant`. + Function space of the corresponding :class:`.Function` or None when `x` is to be mapped to a :class:`.Constant`. Returns ------- - firedrake.Function or firedrake.Constant + firedrake.function.Function or firedrake.constant.Constant Firedrake object representing the PyTorch tensor `x`. """ if x.device.type != "cpu":
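Putting the pieces together, a sketch of the coupling pattern that tests/regression/test_pytorch_coupling.py exercises: a Firedrake solve and misfit functional wrapped as a torch operator and used as the loss of a PyTorch model. The mesh, PDE, boundary condition and network below are hypothetical placeholders (serial run assumed), not part of the patch; only torch_operator and to_torch come from this module.

    import torch
    from firedrake import *
    from firedrake.adjoint import *
    from firedrake.ml.pytorch import torch_operator, to_torch

    mesh = UnitSquareMesh(8, 8)
    V = FunctionSpace(mesh, "CG", 1)
    x, y = SpatialCoordinate(mesh)
    f_exact = Function(V)
    f_exact.interpolate(sin(pi * x) * sin(pi * y))   # reference data

    # Record solve + misfit on the pyadjoint tape, with the source term f as control.
    continue_annotation()
    f = Function(V)
    u = Function(V)
    v = TestFunction(V)
    F_res = (inner(grad(u), grad(v)) - inner(f, v)) * dx
    bcs = DirichletBC(V, 0, "on_boundary")
    solve(F_res == 0, u, bcs=bcs)
    J = assemble(inner(u - f_exact, u - f_exact) * dx)
    Jhat = ReducedFunctional(J, Control(f))
    pause_annotation()

    G = torch_operator(Jhat)                 # differentiable Firedrake block for PyTorch

    model = torch.nn.Sequential(torch.nn.Linear(V.dim(), V.dim())).double()  # match float64 dofs
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    x_in = to_torch(f_exact)                 # network input; here simply the reference data
    for step in range(10):
        optimiser.zero_grad()
        f_t = model(x_in)                    # network proposes a source term
        loss = G(f_t)                        # Firedrake solve + misfit, end-to-end differentiable
        loss.backward()                      # adjoint through Firedrake, autograd through the net
        optimiser.step()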