Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds typing to expression tree #3578

Merged
merged 38 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
0f28844
(wip) First batch of typing
pipliggins May 19, 2023
ea172e0
(wip): Subset edits to typing after adding mypy
pipliggins Jun 7, 2023
f5dd303
(wip): more mypy edits
pipliggins Jun 8, 2023
8854bdc
(wip) More mypy changes
pipliggins Jun 8, 2023
e41a924
(wip): Passing mypy with --allow-redefinition
pipliggins Jun 8, 2023
79aefb8
Fixes, remove most type:ignores (not passing)
pipliggins Nov 27, 2023
ad27152
Merge branch 'develop' into expression-tree-typing
pipliggins Nov 27, 2023
9b281e2
Merge branch 'develop' into expression-tree-typing
pipliggins Dec 1, 2023
f2faa8a
style: pre-commit fixes
pre-commit-ci[bot] Dec 1, 2023
da3b072
edit _from_json to fix subtype incompatibility mypy errors
pipliggins Dec 1, 2023
1dfb1b4
Misc type hinting fixes for mypy (30 remaining)
pipliggins Dec 14, 2023
d9e3ae2
style: pre-commit fixes
pre-commit-ci[bot] Dec 14, 2023
77f00a6
More misc edits to reduce mypy errors (19 remaining)
pipliggins Dec 15, 2023
69f1124
Fix pre-commit issues
pipliggins Dec 15, 2023
44e5161
Edit imports
pipliggins Dec 15, 2023
811cf8e
style: pre-commit fixes
pre-commit-ci[bot] Dec 15, 2023
bb75a85
Remove 'assert' and typing Tuples
pipliggins Dec 15, 2023
07f48d9
Update typing syntax to match 3.10+ style, add __future__ imports
pipliggins Dec 15, 2023
e9a77a8
Mypy passes with `mypy pybamm` command
pipliggins Jan 12, 2024
b0234cb
Merge branch 'develop' into expression-tree-typing
pipliggins Jan 16, 2024
4d6fb32
style: pre-commit fixes
pre-commit-ci[bot] Jan 16, 2024
64f3fc9
Mypy fixes after merge
pipliggins Jan 16, 2024
c9a0ff1
Fix coverage issues
pipliggins Jan 16, 2024
df6501f
style: pre-commit fixes
pre-commit-ci[bot] Jan 16, 2024
ed2288b
Remove unnecessary 'hints.py' file
pipliggins Jan 18, 2024
062bd00
Merge branch 'develop' into expression-tree-typing
pipliggins Jan 18, 2024
5333d42
Further specify types for domain/auxiliary_domain/domains
pipliggins Jan 18, 2024
f7ff456
style: pre-commit fixes
pre-commit-ci[bot] Jan 18, 2024
9cef34b
Stop ignoring UP007 in ruff, as per #3579
pipliggins Jan 19, 2024
21e4107
Move mypy.ini to pyproject.toml
pipliggins Jan 22, 2024
d1ae819
Fix some ignored type errors
pipliggins Jan 25, 2024
fa5d27d
Merge branch 'develop' into expression-tree-typing
pipliggins Jan 25, 2024
d80c981
Move common type definitions to type_definitions.py
pipliggins Feb 8, 2024
8b0a2aa
Replace numbers.Number type hints to work with static checkers
pipliggins Feb 8, 2024
102c2d6
Add from __future__ to type_definitions.py
pipliggins Feb 8, 2024
38a3713
Add TypeAlias hint
pipliggins Feb 9, 2024
ea03b4f
Use typing List/Dict
pipliggins Feb 9, 2024
eeceaa7
Merge branch 'develop' into expression-tree-typing
martinjrobins Feb 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pybamm/citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _reset(self):
self.register("Sulzer2021")
self.register("Harris2020")

@staticmethod
def _caller_name():
"""
Returns the qualified name of classes that call :meth:`register` internally.
Expand Down
2 changes: 1 addition & 1 deletion pybamm/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Experiment:

def __init__(
self,
operating_conditions: list[str],
operating_conditions: list[str | tuple[str]],
martinjrobins marked this conversation as resolved.
Show resolved Hide resolved
period: str = "1 minute",
temperature: float | None = None,
termination: list[str] | None = None,
Expand Down
45 changes: 24 additions & 21 deletions pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
#
# NumpyArray class
#
from __future__ import annotations
import numpy as np
from scipy.sparse import csr_matrix, issparse
from typing import TYPE_CHECKING

import pybamm
from pybamm.util import have_optional_dependency

if TYPE_CHECKING: # pragma: no cover
import sympy
Comment on lines +13 to +14
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to enable type checking with a global variable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to above - this is a statement used by mypy. In this case sympy is used in the file only as a type hint, not during runtime. TYPE_CHECKING is a constant assumed to be True by static type checkers, but is false at runtime; so sympy will only be imported if a type checker like mypy is being run. https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I would prefer to just have sympy always imported. If it is used in the type annotations, then why not just always have it there?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would make sympy a required dependency. Sympy is heavy and I would avoid making it a required dependency (only a subset of features in PyBaMM require sympy)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say that if expression tree is using sympy types, then it is really a dependency. I have not downloaded this branch locally to double check, but I would guess that my IDE would complain if I did not have sympy installed too.

As we add type checking it is going to expose more things that should probably be real dependencies rather than optional ones.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sympy is only used in the expression tree for to_equation() to use with LaTeX I believe, rather than the solvers, hence the optional dependency.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah my point was more that we might have to rethink our optional dependencies if we are using static types. Expression tree depends on sympy if it is using sympy types, but by putting this check here we are pretending that it doesn't depend on sympy.

It is too much for this task, but if we want a typed library and not have all the dependencies for types then we might need to think about how to better separate our subcomponents. Maybe things like the latexify need to go into a separate module if we want it's dependencies to be truly optional



class Array(pybamm.Symbol):
"""
Expand Down Expand Up @@ -36,13 +41,13 @@ class Array(pybamm.Symbol):

def __init__(
self,
entries,
name=None,
domain=None,
auxiliary_domains=None,
domains=None,
entries_string=None,
):
entries: np.ndarray | list[float] | csr_matrix,
name: str | None = None,
domain: list[str] | str | None = None,
auxiliary_domains: dict[str, str] | None = None,
domains: dict[str, list[str] | str] | None = None,
entries_string: str | None = None,
) -> None:
# if
if isinstance(entries, list):
entries = np.array(entries)
Expand All @@ -59,8 +64,6 @@ def __init__(

@classmethod
def _from_json(cls, snippet: dict):
instance = cls.__new__(cls)

if isinstance(snippet["entries"], dict):
matrix = csr_matrix(
(
Expand All @@ -73,14 +76,12 @@ def _from_json(cls, snippet: dict):
else:
matrix = snippet["entries"]

instance.__init__(
return cls(
matrix,
name=snippet["name"],
domains=snippet["domains"],
)

return instance

@property
def entries(self):
return self._entries
Expand All @@ -100,7 +101,7 @@ def entries_string(self):
return self._entries_string

@entries_string.setter
def entries_string(self, value):
def entries_string(self, value: None | tuple):
# We must include the entries in the hash, since different arrays can be
# indistinguishable by class, name and domain alone
# Slightly different syntax for sparse and non-sparse matrices
Expand All @@ -110,10 +111,10 @@ def entries_string(self, value):
entries = self._entries
if issparse(entries):
dct = entries.__dict__
self._entries_string = ["shape", str(dct["_shape"])]
entries_string = ["shape", str(dct["_shape"])]
for key in ["data", "indices", "indptr"]:
self._entries_string += [key, dct[key].tobytes()]
self._entries_string = tuple(self._entries_string)
entries_string += [key, dct[key].tobytes()]
self._entries_string = tuple(entries_string)
# self._entries_string = str(entries.__dict__)
else:
self._entries_string = (entries.tobytes(),)
Expand All @@ -124,7 +125,7 @@ def set_id(self):
(self.__class__, self.name, *self.entries_string, *tuple(self.domain))
)

def _jac(self, variable):
def _jac(self, variable) -> pybamm.Matrix:
"""See :meth:`pybamm.Symbol._jac()`."""
# Return zeros of correct size
jac = csr_matrix((self.size, variable.evaluation_array.count(True)))
Expand All @@ -139,15 +140,15 @@ def create_copy(self):
entries_string=self.entries_string,
)

def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None):
def _base_evaluate(self, t, y, y_dot, inputs):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
return self._entries

def is_constant(self):
"""See :meth:`pybamm.Symbol.is_constant()`."""
return True

def to_equation(self):
def to_equation(self) -> sympy.Array:
"""Returns the value returned by the node when evaluated."""
sympy = have_optional_dependency("sympy")
entries_list = self.entries.tolist()
Expand Down Expand Up @@ -178,7 +179,7 @@ def to_json(self):
return json_dict


def linspace(start, stop, num=50, **kwargs):
def linspace(start: float, stop: float, num: int = 50, **kwargs) -> pybamm.Array:
"""
Creates a linearly spaced array by calling `numpy.linspace` with keyword
arguments 'kwargs'. For a list of 'kwargs' see the
Expand All @@ -187,7 +188,9 @@ def linspace(start, stop, num=50, **kwargs):
return pybamm.Array(np.linspace(start, stop, num, **kwargs))


def meshgrid(x, y, **kwargs):
def meshgrid(
x: pybamm.Array, y: pybamm.Array, **kwargs
) -> tuple[pybamm.Array, pybamm.Array]:
"""
Return coordinate matrices as from coordinate vectors by calling
`numpy.meshgrid` with keyword arguments 'kwargs'. For a list of 'kwargs'
Expand Down
61 changes: 39 additions & 22 deletions pybamm/expression_tree/averages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#
# Classes and methods for averaging
#
from __future__ import annotations
from typing import Callable
import pybamm


Expand All @@ -14,13 +16,19 @@ class _BaseAverage(pybamm.Integral):
The child node
"""

def __init__(self, child, name, integration_variable):
def __init__(
self,
child: pybamm.Symbol,
name: str,
integration_variable: list[pybamm.IndependentVariable]
| pybamm.IndependentVariable,
) -> None:
super().__init__(child, integration_variable)
self.name = name


class XAverage(_BaseAverage):
def __init__(self, child):
def __init__(self, child: pybamm.Symbol) -> None:
if all(n in child.domain[0] for n in ["negative", "particle"]):
x = pybamm.standard_spatial_vars.x_n
elif all(n in child.domain[0] for n in ["positive", "particle"]):
Expand All @@ -30,56 +38,60 @@ def __init__(self, child):
integration_variable = x
super().__init__(child, "x-average", integration_variable)

def _unary_new_copy(self, child):
def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return x_average(child)


class YZAverage(_BaseAverage):
def __init__(self, child):
def __init__(self, child: pybamm.Symbol) -> None:
y = pybamm.standard_spatial_vars.y
z = pybamm.standard_spatial_vars.z
integration_variable = [y, z]
integration_variable: list[pybamm.IndependentVariable] = [y, z]
super().__init__(child, "yz-average", integration_variable)

def _unary_new_copy(self, child):
def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return yz_average(child)


class ZAverage(_BaseAverage):
def __init__(self, child):
integration_variable = [pybamm.standard_spatial_vars.z]
def __init__(self, child: pybamm.Symbol) -> None:
integration_variable: list[pybamm.IndependentVariable] = [
pybamm.standard_spatial_vars.z
]
super().__init__(child, "z-average", integration_variable)

def _unary_new_copy(self, child):
def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return z_average(child)


class RAverage(_BaseAverage):
def __init__(self, child):
integration_variable = [pybamm.SpatialVariable("r", child.domain)]
def __init__(self, child: pybamm.Symbol) -> None:
integration_variable: list[pybamm.IndependentVariable] = [
pybamm.SpatialVariable("r", child.domain)
]
super().__init__(child, "r-average", integration_variable)

def _unary_new_copy(self, child):
def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return r_average(child)


class SizeAverage(_BaseAverage):
def __init__(self, child, f_a_dist):
def __init__(self, child: pybamm.Symbol, f_a_dist) -> None:
R = pybamm.SpatialVariable("R", domains=child.domains, coord_sys="cartesian")
integration_variable = [R]
integration_variable: list[pybamm.IndependentVariable] = [R]
super().__init__(child, "size-average", integration_variable)
self.f_a_dist = f_a_dist

def _unary_new_copy(self, child):
def _unary_new_copy(self, child: pybamm.Symbol):
"""See :meth:`UnaryOperator._unary_new_copy()`."""
return size_average(child, f_a_dist=self.f_a_dist)


def x_average(symbol):
def x_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
"""
Convenience function for creating an average in the x-direction.

Expand Down Expand Up @@ -168,7 +180,7 @@ def x_average(symbol):
return XAverage(symbol)


def z_average(symbol):
def z_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
"""
Convenience function for creating an average in the z-direction.

Expand Down Expand Up @@ -205,7 +217,7 @@ def z_average(symbol):
return ZAverage(symbol)


def yz_average(symbol):
def yz_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
"""
Convenience function for creating an average in the y-z-direction.

Expand Down Expand Up @@ -239,11 +251,11 @@ def yz_average(symbol):
return YZAverage(symbol)


def xyz_average(symbol):
def xyz_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
return yz_average(x_average(symbol))


def r_average(symbol):
def r_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
"""
Convenience function for creating an average in the r-direction.

Expand Down Expand Up @@ -286,7 +298,9 @@ def r_average(symbol):
return RAverage(symbol)


def size_average(symbol, f_a_dist=None):
def size_average(
symbol: pybamm.Symbol, f_a_dist: pybamm.Symbol | None = None
) -> pybamm.Symbol:
"""Convenience function for averaging over particle size R using the area-weighted
particle-size distribution.

Expand Down Expand Up @@ -339,7 +353,10 @@ def size_average(symbol, f_a_dist=None):
return SizeAverage(symbol, f_a_dist)


def _sum_of_averages(symbol, average_function):
def _sum_of_averages(
symbol: pybamm.Addition | pybamm.Subtraction,
average_function: Callable[[pybamm.Symbol], pybamm.Symbol],
):
if isinstance(symbol, pybamm.Addition):
return average_function(symbol.left) + average_function(symbol.right)
elif isinstance(symbol, pybamm.Subtraction):
Expand Down
Loading
Loading