Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FML #3041

Merged
merged 12 commits into from
Oct 25, 2023
132 changes: 85 additions & 47 deletions firedrake/fml/form_manipulation_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import functools
import operator
from firedrake import Constant, Function
from collections.abc import Callable, Sequence, Mapping
from typing import Any, Union


__all__ = ["Label", "Term", "LabelledForm", "identity", "drop", "all_terms",
Expand All @@ -26,26 +28,26 @@ class Term(object):

__slots__ = ["form", "labels"]

def __init__(self, form, label_dict=None):
def __init__(self, form: ufl.Form, label_dict: Mapping = None):
"""
Parameters
----------
form : ufl.Form
form
The form for this terms.
label_dict : `dict`, optional
label_dict
Dictionary of key-value pairs corresponding to current form labels.
Defaults to None.
"""
self.form = form
self.labels = label_dict or {}

def get(self, label):
def get(self, label: "Label") -> Any:
"""
Returns the value of a label.

Parameters
----------
label : Label
label
The label to return the value of.

Returns
Expand All @@ -55,36 +57,40 @@ def get(self, label):
"""
return self.labels.get(label.label)

def has_label(self, *labels, return_tuple=False):
def has_label(
self,
*labels: Union[Sequence["Label"], "Label"],
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved
dham marked this conversation as resolved.
Show resolved Hide resolved
return_tuple: bool = False
) -> Union[tuple[bool], bool]:
"""
Whether the term has the specified labels attached to it.

Parameters
----------
*labels : Label
*labels
A label or series of labels. A tuple is automatically returned if
multiple labels are provided as arguments.
return_tuple : `bool`, optional
return_tuple
If True, forces a tuple to be returned even if only one label is
provided as an argument. Defaults to False.

Returns
-------
bool or tuple
bool
Booleans corresponding to whether the term has the specified labels.
"""
if len(labels) == 1 and not return_tuple:
return labels[0].label in self.labels
else:
return tuple(self.has_label(l) for l in labels)

def __add__(self, other):
def __add__(self, other: Union["Term", "LabelledForm"]) -> "LabelledForm":
"""
Adds a term or labelled form to this term.

Parameters
----------
other : Term or LabelledForm
other
The term or labelled form to add to this term.

Returns
Expand All @@ -105,13 +111,13 @@ def __add__(self, other):

__radd__ = __add__

def __sub__(self, other):
def __sub__(self, other: Union["Term", "LabelledForm"]) -> "LabelledForm":
"""
Subtracts a term or labelled form from this term.

Parameters
----------
other : Term or LabelledForm
other
The term or labelled form to subtract from this term.

Returns
Expand All @@ -122,13 +128,16 @@ def __sub__(self, other):
other = other * Constant(-1.0)
return self + other

def __mul__(self, other):
def __mul__(
self,
other: Union[float, Constant, ufl.algebra.Product]
) -> "Term":
"""
Multiplies this term by another quantity.

Parameters
----------
other : float, Constant or ufl.algebra.Product
other
The quantity to multiply this term by.

Returns
Expand All @@ -140,13 +149,16 @@ def __mul__(self, other):

__rmul__ = __mul__

def __truediv__(self, other):
def __truediv__(
self,
other: Union[float, Constant, ufl.algebra.Product]
) -> "Term":
"""
Divides this term by another quantity.

Parameters
----------
other : float, Constant or ufl.algebra.Product
other
The quantity to divide this term by.

Returns
Expand Down Expand Up @@ -175,7 +187,7 @@ class LabelledForm(object):
"""
__slots__ = ["terms"]

def __init__(self, *terms):
def __init__(self, *terms: Sequence[Term]):
"""
Parameters
----------
Expand All @@ -193,13 +205,16 @@ def __init__(self, *terms):
raise TypeError('Can only pass terms or a LabelledForm to LabelledForm')
self.terms = list(terms)

def __add__(self, other):
def __add__(
self,
other: Union[ufl.Form, Term, "LabelledForm"]
) -> "LabelledForm":
"""
Adds a form, term or labelled form to this labelled form.

Parameters
----------
other : ufl.Form, Term or LabelledForm
other
The form, term or labelled form to add to this labelled form.

Returns
Expand All @@ -220,13 +235,16 @@ def __add__(self, other):

__radd__ = __add__

def __sub__(self, other):
def __sub__(
self,
other: Union[ufl.Form, Term, "LabelledForm"]
) -> "LabelledForm":
"""
Subtracts a form, term or labelled form from this labelled form.

Parameters
----------
other : ufl.Form, Term or LabelledForm
other
The form, term or labelled form to subtract from this labelled form.

Returns
Expand All @@ -244,13 +262,16 @@ def __sub__(self, other):
# Make new Term for other and subtract it
return LabelledForm(*self, Term(Constant(-1.)*other))

def __mul__(self, other):
def __mul__(
self,
other: Union[float, Constant, ufl.algebra.Product]
) -> "LabelledForm":
"""
Multiplies this labelled form by another quantity.

Parameters
----------
other : float, Constant or ufl.algebra.Product
other
The quantity to multiply this labelled form by. All terms in the
form are multiplied.

Expand All @@ -261,13 +282,16 @@ def __mul__(self, other):
"""
return self.label_map(all_terms, lambda t: Term(other*t.form, t.labels))

def __truediv__(self, other):
def __truediv__(
self,
other: Union[float, Constant, ufl.algebra.Product]
) -> "LabelledForm":
"""
Divides this labelled form by another quantity.

Parameters
----------
other : float, Constant or ufl.algebra.Product
other
The quantity to divide this labelled form by. All terms in the form
are divided.

Expand All @@ -280,27 +304,31 @@ def __truediv__(self, other):

__rmul__ = __mul__

def __iter__(self):
def __iter__(self) -> Sequence:
"""Returns an iterable of the terms in the labelled form."""
return iter(self.terms)

def __len__(self):
def __len__(self) -> int:
"""Returns the number of terms in the labelled form."""
return len(self.terms)

def label_map(self, term_filter, map_if_true=identity,
map_if_false=identity):
def label_map(
self,
term_filter: Callable,
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved
map_if_true: Callable = identity,
map_if_false: Callable = identity
) -> "LabelledForm":
"""
Maps selected terms in the labelled form, returning a new labelled form.

Parameters
----------
term_filter : `callable`
term_filter
A function to filter the labelled form's terms.
map_if_true : `callable`, optional
map_if_true
How to map the terms for which the term_filter returns True.
Defaults to identity.
map_if_false : `callable`, optional
map_if_false
How to map the terms for which the term_filter returns False.
Defaults to identity.

Expand Down Expand Up @@ -329,7 +357,7 @@ def label_map(self, term_filter, map_if_true=identity,
return new_labelled_form

@property
def form(self):
def form(self) -> ufl.Form:
"""
Provides the whole form from the labelled form.

Expand All @@ -355,32 +383,42 @@ class Label(object):

__slots__ = ["label", "default_value", "value", "validator"]

def __init__(self, label, *, value=True, validator=None):
def __init__(
self,
label,
*,
value: Any = True,
validator: Union[Callable, None] = None
):
"""
Parameters
----------
label : str
label
The name of the label.
value : `any`, optional
value
The value for the label to take. Can be any type (subject to the
validator). Defaults to True.
validator : `callable`, optional
validator
Function to check the validity of any value later passed to the
label. Defaults to None.
"""
self.label = label
self.default_value = value
self.validator = validator

def __call__(self, target, value=None):
def __call__(
self,
target: Union[ufl.Form, Term, LabelledForm],
value: Any = None
) -> Union[Term, LabelledForm]:
"""
Applies the label to a form or term.

Parameters
----------
target : ufl.Form, Term or LabelledForm
target
The form, term or labelled form to be labelled.
value : Any, optional
value
The value to attach to this label. Defaults to None.

Raises
Expand All @@ -391,7 +429,7 @@ def __call__(self, target, value=None):

Returns
-------
Term or LabelledForm
Union[Term, LabelledForm]
A Term is returned if the target is a Term,
otherwise a LabelledForm is returned.
"""
Expand All @@ -414,7 +452,7 @@ def __call__(self, target, value=None):
else:
raise ValueError("Unable to label %s" % target)

def remove(self, target):
def remove(self, target: Union[Term, LabelledForm]):
"""
Removes a label from a term or labelled form.

Expand All @@ -423,7 +461,7 @@ def remove(self, target):

Parameters
----------
target : Term or LabelledForm
target
Term or labelled form to have this label removed from.

Raises
Expand All @@ -444,7 +482,7 @@ def remove(self, target):
else:
raise ValueError("Unable to unlabel %s" % target)

def update_value(self, target, new):
def update_value(self, target: Union[Term, LabelledForm], new: Any):
"""
Updates the label of a term or labelled form.

Expand All @@ -453,9 +491,9 @@ def update_value(self, target, new):

Parameters
----------
target : Term or LabelledForm
target
Term or labelled form to have this label updated.
new : Any
new
The new value for this label to take. The type is subject to the
label's validator (if it has one).

Expand Down
Loading
Loading