diff --git a/FESTIM/generic_simulation.py b/FESTIM/generic_simulation.py index afc266fa8..a56188503 100644 --- a/FESTIM/generic_simulation.py +++ b/FESTIM/generic_simulation.py @@ -218,6 +218,11 @@ def initialise(self): raise AttributeError("dt must be None in steady state simulations") if self.settings.transient and self.dt is None: raise AttributeError("dt must be provided in transient simulations") + + # initialise dt + if self.settings.transient: + self.dt.initialise_value() + self.h_transport_problem = HTransportProblem( self.mobile, self.traps, self.T, self.settings, self.initial_conditions ) @@ -328,7 +333,7 @@ def iterate(self): # avoid t > final_time next_time = self.t + float(self.dt.value) - if next_time > self.settings.final_time: + if next_time > self.settings.final_time and self.t != self.settings.final_time: self.dt.value.assign(self.settings.final_time - self.t) def display_time(self): diff --git a/FESTIM/stepsize.py b/FESTIM/stepsize.py index 7e439e7ce..bcbc20a7a 100644 --- a/FESTIM/stepsize.py +++ b/FESTIM/stepsize.py @@ -40,7 +40,14 @@ def __init__( "stepsize_stop_max": stepsize_stop_max, "dt_min": dt_min, } - self.value = f.Constant(initial_value, name="dt") + self.initial_value = initial_value + self.value = None + self.initialise_value() + + def initialise_value(self): + """Creates a fenics.Constant object initialised with self.initial_value + and stores it in self.value""" + self.value = f.Constant(self.initial_value, name="dt") def adapt(self, t, nb_it, converged): """Changes the stepsize based on convergence. diff --git a/Tests/simulation/test_initialise.py b/Tests/simulation/test_initialise.py index fa3d492ff..d72ae2c4b 100644 --- a/Tests/simulation/test_initialise.py +++ b/Tests/simulation/test_initialise.py @@ -49,3 +49,25 @@ def test_initialise_sets_t_to_zero(): # check that my_model.t is reinitialised to zero assert my_model.t == 0 + + +def test_initialise_initialise_dt(): + """Creates a Simulation object and checks that .initialise() sets + the value attribute of the dt attribute to dt.initial_value + """ + # build + my_model = F.Simulation() + my_model.mesh = F.MeshFromVertices([1, 2, 3]) + my_model.materials = F.Material(id=1, D_0=1, E_D=0, thermal_cond=1) + my_model.T = F.Temperature(100) + my_model.dt = F.Stepsize(initial_value=3) + my_model.settings = F.Settings( + absolute_tolerance=1e-10, relative_tolerance=1e-10, final_time=4 + ) + my_model.dt.value.assign(26) + + # run + my_model.initialise() + + # test + assert my_model.dt.value(2) == 3