diff --git a/src/bin/bench.rs b/src/bin/bench.rs index a82560e..37c2528 100644 --- a/src/bin/bench.rs +++ b/src/bin/bench.rs @@ -9,8 +9,9 @@ fn main() { let dt = 0.01; let eom = model::Lorenz63::default(); let teo = explicit::rk4(eom, dt); - let mut x = arr1(&[1.0, 0.0, 0.0]); + let mut buf = explicit::RK4Buffer::new_buffer(&teo); + let mut x: Array1 = arr1(&[1.0, 0.0, 0.0]); for _ in 0..100_000_000 { - teo.iterate(&mut x); + teo.iterate_buf(&mut x, &mut buf); } } diff --git a/src/explicit.rs b/src/explicit.rs index b76acde..a1f08c4 100644 --- a/src/explicit.rs +++ b/src/explicit.rs @@ -148,27 +148,34 @@ pub struct RK4Buffer { k3: Array, } -impl TimeEvolutionBuffered for RK4 +impl RK4Buffer + where A: Scalar, + D: Dimension +{ + pub fn new_buffer(t: &T) -> RK4Buffer + where T: ModelSize + { + RK4Buffer { + x: Array::zeros(t.model_size()), + k1: Array::zeros(t.model_size()), + k2: Array::zeros(t.model_size()), + k3: Array::zeros(t.model_size()), + } + } +} + +impl TimeEvolutionBuffered> for RK4 where A: Scalar, S: DataMut, D: Dimension, F: Explicit { type Scalar = F::Scalar; - type Buffer = RK4Buffer; - fn get_buffer(&self) -> Self::Buffer { - RK4Buffer { - x: Array::zeros(self.model_size()), - k1: Array::zeros(self.model_size()), - k2: Array::zeros(self.model_size()), - k3: Array::zeros(self.model_size()), - } - } fn iterate_buf<'a>(&self, mut x: &'a mut ArrayBase, - mut buf: &mut Self::Buffer) + mut buf: &mut RK4Buffer) -> &'a mut ArrayBase { let dt = self.dt; let dt_2 = self.dt * into_scalar(0.5); diff --git a/src/traits.rs b/src/traits.rs index 16a4a1a..967141a 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -45,21 +45,6 @@ pub trait TimeEvolutionBase: ModelSize + TimeStep fn iterate<'a>(&self, &'a mut ArrayBase) -> &'a mut ArrayBase; } -/// Time-evolution operator with buffer -pub trait TimeEvolutionBuffered: ModelSize + TimeStep - where S: DataMut, - D: Dimension -{ - type Scalar: Scalar; - type Buffer; - fn get_buffer(&self) -> Self::Buffer; - /// calculate next step - fn iterate_buf<'a>(&self, - &'a mut ArrayBase, - &mut Self::Buffer) - -> &'a mut ArrayBase; -} - pub trait TimeEvolution : TimeEvolutionBase, D, Scalar = A, Time = A::Real> + TimeEvolutionBase, D, Scalar = A, Time = A::Real> @@ -68,3 +53,13 @@ pub trait TimeEvolution D: Dimension { } + +/// Time-evolution operator with buffer +pub trait TimeEvolutionBuffered: ModelSize + TimeStep + where S: DataMut, + D: Dimension +{ + type Scalar: Scalar; + /// calculate next step + fn iterate_buf<'a>(&self, &'a mut ArrayBase, &mut Buffer) -> &'a mut ArrayBase; +}