Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

delay xarray.DataArray initialization #3862

Merged
merged 4 commits into from
Mar 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def __init__(
+ "(note processing of 3D variables is not yet implemented)"
)

# xr_data_array is initialized when needed
self._xr_data_array = None

def initialise_0D(self):
# initialise empty array of the correct size
entries = np.empty(len(self.t_pts))
Expand All @@ -130,8 +133,9 @@ def initialise_0D(self):
entries, self.t_pts, initial=float(self.cumtrapz_ic)
)

# set up interpolation
self._xr_data_array = xr.DataArray(entries, coords=[("t", self.t_pts)])
# save attributes for interpolation
self.entries_for_interp = entries
self.coords_for_interp = {"t": self.t_pts}

self.entries = entries
self.dimensions = 0
Expand Down Expand Up @@ -185,11 +189,9 @@ def initialise_1D(self, fixed_t=False):
# Set first_dim_pts to edges for nicer plotting
self.first_dim_pts = edges

# set up interpolation
self._xr_data_array = xr.DataArray(
entries_for_interp,
coords=[(self.first_dimension, pts_for_interp), ("t", self.t_pts)],
)
# save attributes for interpolation
self.entries_for_interp = entries_for_interp
self.coords_for_interp = {self.first_dimension: pts_for_interp, "t": self.t_pts}

def initialise_2D(self):
"""
Expand Down Expand Up @@ -289,15 +291,13 @@ def initialise_2D(self):
self.first_dim_pts = first_dim_edges
self.second_dim_pts = second_dim_edges

# set up interpolation
self._xr_data_array = xr.DataArray(
entries_for_interp,
coords={
self.first_dimension: first_dim_pts_for_interp,
self.second_dimension: second_dim_pts_for_interp,
"t": self.t_pts,
},
)
# save attributes for interpolation
self.entries_for_interp = entries_for_interp
self.coords_for_interp = {
self.first_dimension: first_dim_pts_for_interp,
self.second_dimension: second_dim_pts_for_interp,
"t": self.t_pts,
}

def initialise_2D_scikit_fem(self):
y_sol = self.mesh.edges["y"]
Expand Down Expand Up @@ -331,11 +331,9 @@ def initialise_2D_scikit_fem(self):
self.first_dim_pts = y_sol
self.second_dim_pts = z_sol

# set up interpolation
self._xr_data_array = xr.DataArray(
entries,
coords={"y": y_sol, "z": z_sol, "t": self.t_pts},
)
# save attributes for interpolation
self.entries_for_interp = entries
self.coords_for_interp = {"y": y_sol, "z": z_sol, "t": self.t_pts}

def _process_spatial_variable_names(self, spatial_variable):
if len(spatial_variable) == 0:
Expand Down Expand Up @@ -366,11 +364,23 @@ def _process_spatial_variable_names(self, spatial_variable):
f"Spatial variable name not recognized for {spatial_variable}"
)

def _initialize_xr_data_array(self):
"""
Initialize the xarray DataArray for interpolation. We don't do this by
default as it has some overhead (~75 us) and sometimes we only need the entries
of the processed variable, not the xarray object for interpolation.
"""
entries = self.entries_for_interp
coords = self.coords_for_interp
self._xr_data_array = xr.DataArray(entries, coords=coords)

def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True):
"""
Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R),
using interpolation
"""
if self._xr_data_array is None:
self._initialize_xr_data_array()
kwargs = {"t": t, "x": x, "r": r, "y": y, "z": z, "R": R}
# Remove any None arguments
kwargs = {key: value for key, value in kwargs.items() if value is not None}
Expand Down
Loading