Skip to content

Commit

Permalink
#3200 fix another failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
brosaplanella committed Aug 3, 2023
1 parent f976923 commit 07feb0a
Showing 1 changed file with 28 additions and 22 deletions.
50 changes: 28 additions & 22 deletions tests/unit/test_solvers/test_processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,22 +926,22 @@ def test_processed_var_2D_unknown_domain(self):
domain=["domain B"],
auxiliary_domains={"secondary": ["domain A"]},
)
a = pybamm.SpatialVariable("a", domain=["domain A"])
b = pybamm.SpatialVariable(
"b",
x = pybamm.SpatialVariable("x", domain=["domain A"])
z = pybamm.SpatialVariable(
"z",
domain=["domain B"],
auxiliary_domains={"secondary": ["domain A"]},
)

geometry = {
"domain A": {a: {"min": 0, "max": 1}},
"domain B": {b: {"min": 0, "max": 1}},
"domain A": {x: {"min": 0, "max": 1}},
"domain B": {z: {"min": 0, "max": 1}},
}
submesh_types = {
"domain A": pybamm.Uniform1DSubMesh,
"domain B": pybamm.Uniform1DSubMesh,
}
var_pts = {a: 10, b: 10}
var_pts = {x: 10, z: 20}
mesh = pybamm.Mesh(geometry, submesh_types, var_pts)

spatial_methods = {
Expand All @@ -951,20 +951,20 @@ def test_processed_var_2D_unknown_domain(self):

disc = pybamm.Discretisation(mesh, spatial_methods)
disc.set_variable_slices([var])
a_sol = disc.process_symbol(a).entries[:, 0]
b_sol = disc.process_symbol(b).entries[:, 0]
x_sol = disc.process_symbol(x).entries[:, 0]
z_sol = disc.process_symbol(z).entries[:, 0]
# Keep only the first iteration of entries
b_sol = b_sol[: len(b_sol) // len(a_sol)]
z_sol = z_sol[: len(z_sol) // len(x_sol)]
var_sol = disc.process_symbol(var)
t_sol = np.linspace(0, 1)
y_sol = np.ones(len(a_sol) * len(b_sol))[:, np.newaxis] * np.linspace(0, 5)
y_sol = np.ones(len(x_sol) * len(z_sol))[:, np.newaxis] * np.linspace(0, 5)

var_casadi = to_casadi(var_sol, y_sol)
model = pybamm.BaseModel()
model.geometry = pybamm.Geometry(
{
"domain A": {a: {"min": 0, "max": 1}},
"domain B": {b: {"min": 0, "max": 1}},
"domain A": {x: {"min": 0, "max": 1}},
"domain B": {z: {"min": 0, "max": 1}},
}
)
processed_var = pybamm.ProcessedVariable(
Expand All @@ -975,22 +975,28 @@ def test_processed_var_2D_unknown_domain(self):
)
# 3 vectors
np.testing.assert_array_equal(
processed_var(t_sol, a_sol, b_sol).shape, (10, 40, 50)
processed_var(t=t_sol, x=x_sol, z=z_sol).shape, (20, 10, 50)
)
np.testing.assert_array_almost_equal(
processed_var(t_sol, a_sol, b_sol),
np.reshape(y_sol, [len(b_sol), len(a_sol), len(t_sol)]),
processed_var(t_sol, x=x_sol, z=z_sol),
np.reshape(y_sol, [len(z_sol), len(x_sol), len(t_sol)]),
)
# 2 vectors, 1 scalar
np.testing.assert_array_equal(processed_var(0.5, a_sol, b_sol).shape, (10, 40))
np.testing.assert_array_equal(processed_var(t_sol, 0.2, b_sol).shape, (10, 50))
np.testing.assert_array_equal(processed_var(t_sol, a_sol, 0.5).shape, (40, 50))
np.testing.assert_array_equal(
processed_var(t=0.5, x=x_sol, z=z_sol).shape, (20, 10)
)
np.testing.assert_array_equal(
processed_var(t=t_sol, x=0.2, z=z_sol).shape, (20, 50)
)
np.testing.assert_array_equal(
processed_var(t=t_sol, x=x_sol, z=0.5).shape, (10, 50)
)
# 1 vectors, 2 scalar
np.testing.assert_array_equal(processed_var(0.5, 0.2, b_sol).shape, (10,))
np.testing.assert_array_equal(processed_var(0.5, a_sol, 0.5).shape, (40,))
np.testing.assert_array_equal(processed_var(t_sol, 0.2, 0.5).shape, (50,))
np.testing.assert_array_equal(processed_var(t=0.5, x=0.2, z=z_sol).shape, (20,))
np.testing.assert_array_equal(processed_var(t=0.5, x=x_sol, z=0.5).shape, (10,))
np.testing.assert_array_equal(processed_var(t=t_sol, x=0.2, z=0.5).shape, (50,))
# 3 scalars
np.testing.assert_array_equal(processed_var(0.2, 0.2, 0.2).shape, ())
np.testing.assert_array_equal(processed_var(t=0.2, x=0.2, z=0.2).shape, ())

def test_3D_raises_error(self):
var = pybamm.Variable(
Expand Down

0 comments on commit 07feb0a

Please sign in to comment.