Skip to content

Commit

Permalink
#3200 fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brosaplanella committed Aug 3, 2023
1 parent 1c78545 commit 753027f
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 50 deletions.
40 changes: 25 additions & 15 deletions pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,6 @@ def __init__(
variables += list(geometry[domain].keys())
self.spatial_variables[domain_level] = variables

self.spatial_variable_names = {
k: self._process_spatial_variable_names(v)
for k, v in self.spatial_variables.items()
}

# Sensitivity starts off uninitialized, only set when called
self._sensitivities = None
self.solution_sensitivities = solution.sensitivities
Expand Down Expand Up @@ -186,6 +181,10 @@ def initialise_1D(self, fixed_t=False):
# assign attributes for reference (either x_sol or r_sol)
self.entries = entries
self.dimensions = 1
self.spatial_variable_names = {
k: self._process_spatial_variable_names(v)
for k, v in self.spatial_variables.items()
}
self.first_dimension = self.spatial_variable_names["primary"]

# assign attributes for reference
Expand Down Expand Up @@ -281,6 +280,11 @@ def initialise_2D(self):
axis=1,
)

self.spatial_variable_names = {
k: self._process_spatial_variable_names(v)
for k, v in self.spatial_variables.items()
}

# Process r-x, x-z, r-R, R-x, or R-z
if self.domain[0].endswith("particle") and self.domains["secondary"][
0
Expand Down Expand Up @@ -374,19 +378,25 @@ def initialise_2D_scikit_fem(self):
def _process_spatial_variable_names(self, spatial_variable):
if len(spatial_variable) == 0:
return None
elif spatial_variable in [["r_n"], ["r_p"]]:

# Extract names
raw_names = []
for var in spatial_variable:
if isinstance(var, str):
raw_names.append(var)
else:
raw_names.append(var.name)

# Rename battery variables to match PyBaMM convention
if all([var.startswith("r") for var in raw_names]):
return "r"
elif spatial_variable in [["x_n"], ["x_s"], ["x_p"], ["x_n", "x_s", "x_p"]]:
elif all([var.startswith("x") for var in raw_names]):
return "x"
elif spatial_variable in [["R_n"], ["R_p"]]:
elif all([var.startswith("R") for var in raw_names]):
return "R"
elif len(spatial_variable) == 1:
if isinstance(spatial_variable[0], str):
return spatial_variable[0]
else:
return spatial_variable[0].name
else: # pragma: no cover
# should not be reached
elif len(raw_names) == 1:
return spatial_variable[0]
else:
raise NotImplementedError(
"Spatial variable name not recognized for {}".format(spatial_variable)
)
Expand Down
8 changes: 2 additions & 6 deletions tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,7 @@ def get_cylindrical_discretisation_for_testing(
)


def get_base_model_with_battery_geometry(
include_particles=True, options=None, form_factor="pouch"
):
def get_base_model_with_battery_geometry(**kwargs):
model = pybamm.BaseModel()
model.geometry = pybamm.battery_geometry(
include_particles=include_particles, options=options, form_factor=form_factor
)
model.geometry = pybamm.battery_geometry(**kwargs)
return model
Loading

0 comments on commit 753027f

Please sign in to comment.