diff --git a/CHANGELOG.md b/CHANGELOG.md index fdaf360d6f..b1e093f452 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,9 +3,10 @@ ## Features - Added support for macOS arm64 (M-series) platforms. ([#3789](https://github.com/pybamm-team/PyBaMM/pull/3789)) +- Added the ability to specify a custom solver tolerance in `get_initial_stoichiometries` and related functions ([#3714](https://github.com/pybamm-team/PyBaMM/pull/3714)) +- Modified `step` function to take an array of time `t_eval` as an argument and deprecated use of `npts`. ([#3627](https://github.com/pybamm-team/PyBaMM/pull/3627)) - Renamed "electrode diffusivity" to "particle diffusivity" as a non-breaking change with a deprecation warning ([#3624](https://github.com/pybamm-team/PyBaMM/pull/3624)) - Add support for BPX version 0.4.0 which allows for blended electrodes and user-defined parameters in BPX([#3414](https://github.com/pybamm-team/PyBaMM/pull/3414)) -- Added the ability to specify a custom solver tolerance in `get_initial_stoichiometries` and related functions ([#3714](https://github.com/pybamm-team/PyBaMM/pull/3714)) ## Bug Fixes diff --git a/examples/scripts/SPMe_step.py b/examples/scripts/SPMe_step.py index 90d6f3d017..f277c0e790 100644 --- a/examples/scripts/SPMe_step.py +++ b/examples/scripts/SPMe_step.py @@ -32,12 +32,14 @@ # step model dt = 500 +# t_eval is an array of time in the interval 0 to dt, dt being size of the step. +t_eval = np.array([0, 50, 100, 200, 500]) time = 0 end_time = solution.t[-1] step_solver = pybamm.CasadiSolver() step_solution = None while time < end_time: - step_solution = step_solver.step(step_solution, model, dt=dt, npts=10) + step_solution = step_solver.step(step_solution, model, dt=dt, t_eval=t_eval) time += dt # plot diff --git a/pybamm/simulation.py b/pybamm/simulation.py index a2b260ab43..3dfdac94b5 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -789,7 +789,7 @@ def solve( current_solution, model, dt, - npts=npts, + t_eval=np.linspace(0, dt, npts), save=False, **kwargs, ) @@ -970,7 +970,7 @@ def run_padding_rest(self, kwargs, rest_time, step_solution): step_solution, model, rest_time, - npts=npts, + t_eval=np.linspace(0, rest_time, npts), save=False, **kwargs, ) @@ -978,7 +978,13 @@ def run_padding_rest(self, kwargs, rest_time, step_solution): return step_solution_with_rest def step( - self, dt, solver=None, npts=2, save=True, starting_solution=None, **kwargs + self, + dt, + solver=None, + t_eval=None, + save=True, + starting_solution=None, + **kwargs, ): """ A method to step the model forward one timestep. This method will @@ -990,9 +996,10 @@ def step( The timestep over which to step the solution solver : :class:`pybamm.BaseSolver` The solver to use to solve the model. - npts : int, optional - The number of points at which the solution will be returned during - the step dt. Default is 2 (returns the solution at t0 and t0 + dt). + t_eval : list or numpy.ndarray, optional + An array of times at which to return the solution during the step + (Note: t_eval is the time measured from the start of the step, so should start at 0 and end at dt). + By default, the solution is returned at t0 and t0 + dt. save : bool Turn on to store the solution of all previous timesteps starting_solution : :class:`pybamm.Solution` @@ -1012,7 +1019,12 @@ def step( starting_solution = self._solution self._solution = solver.step( - starting_solution, self.built_model, dt, npts=npts, save=save, **kwargs + starting_solution, + self.built_model, + dt, + t_eval=t_eval, + save=save, + **kwargs, ) return self.solution diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index a2b4c305c2..be26306bbd 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -1090,7 +1090,8 @@ def step( old_solution, model, dt, - npts=2, + t_eval=None, + npts=None, inputs=None, save=True, ): @@ -1108,9 +1109,11 @@ def step( initial_conditions dt : numeric type The timestep (in seconds) over which to step the solution - npts : int, optional - The number of points at which the solution will be returned during - the step dt. default is 2 (returns the solution at t0 and t0 + dt). + t_eval : list or numpy.ndarray, optional + An array of times at which to return the solution during the step + (Note: t_eval is the time measured from the start of the step, so should start at 0 and end at dt). + By default, the solution is returned at t0 and t0 + dt. + npts : deprecated inputs : dict, optional Any input parameters to pass to the model when solving save : bool @@ -1149,10 +1152,28 @@ def step( f"Step time must be at least {pybamm.TimerTime(step_start_offset)}" ) + # Raise deprecation warning for npts and convert it to t_eval + if npts is not None: + warnings.warn( + "The 'npts' parameter is deprecated, use 't_eval' instead.", + DeprecationWarning, + stacklevel=2, + ) + t_eval = np.linspace(0, dt, npts) + + if t_eval is not None: + # Checking if t_eval lies within range + if t_eval[0] != 0 or t_eval[-1] != dt: + raise pybamm.SolverError( + "Elements inside array t_eval must lie in the closed interval 0 to dt" + ) + + else: + t_eval = np.array([0, dt]) + t_start = old_solution.t[-1] + t_eval = t_start + t_eval t_end = t_start + dt - # Calculate t_eval - t_eval = np.linspace(t_start, t_end, npts) if t_start == 0: t_start_shifted = t_start diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index 577e50e68b..dcd4b5b856 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -75,6 +75,22 @@ def test_nonmonotonic_teval(self): ): solver.step(None, model, dt) + # Checking if array t_eval lies within range + dt = 2 + t_eval = np.array([0, 1]) + with self.assertRaisesRegex( + pybamm.SolverError, + "Elements inside array t_eval must lie in the closed interval 0 to dt", + ): + solver.step(None, model, dt, t_eval=t_eval) + + t_eval = np.array([1, dt]) + with self.assertRaisesRegex( + pybamm.SolverError, + "Elements inside array t_eval must lie in the closed interval 0 to dt", + ): + solver.step(None, model, dt, t_eval=t_eval) + def test_solution_time_length_fail(self): model = pybamm.BaseModel() v = pybamm.Scalar(1)