Skip to content

Commit

Permalink
Merge pull request #32 from termoshtt/buffered
Browse files Browse the repository at this point in the history
TimeEvolution with buffer
  • Loading branch information
termoshtt authored Jul 26, 2017
2 parents 1fc45fb + 9916d38 commit 38113b1
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/bin/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ fn main() {
let dt = 0.01;
let eom = model::Lorenz63::default();
let teo = explicit::rk4(eom, dt);
let mut x = arr1(&[1.0, 0.0, 0.0]);
let mut buf = explicit::RK4Buffer::new_buffer(&teo);
let mut x: Array1<f64> = arr1(&[1.0, 0.0, 0.0]);
for _ in 0..100_000_000 {
teo.iterate(&mut x);
teo.iterate_buf(&mut x, &mut buf);
}
}
70 changes: 70 additions & 0 deletions src/explicit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,73 @@ impl<A, S, D, F> TimeEvolutionBase<S, D> for RK4<F, F::Time>
k4
}
}

pub struct RK4Buffer<A, D> {
x: Array<A, D>,
k1: Array<A, D>,
k2: Array<A, D>,
k3: Array<A, D>,
}

impl<A, D> RK4Buffer<A, D>
where A: Scalar,
D: Dimension
{
pub fn new_buffer<T>(t: &T) -> RK4Buffer<A, D>
where T: ModelSize<D>
{
RK4Buffer {
x: Array::zeros(t.model_size()),
k1: Array::zeros(t.model_size()),
k2: Array::zeros(t.model_size()),
k3: Array::zeros(t.model_size()),
}
}
}

impl<A, S, D, F> TimeEvolutionBufferedBase<S, D, RK4Buffer<A, D>> for RK4<F, F::Time>
where A: Scalar,
S: DataMut<Elem = A>,
D: Dimension,
F: Explicit<S, D, Time = A::Real, Scalar = A>
{
type Scalar = F::Scalar;

fn iterate_buf<'a>(&self,
mut x: &'a mut ArrayBase<S, D>,
mut buf: &mut RK4Buffer<A, D>)
-> &'a mut ArrayBase<S, D> {
let dt = self.dt;
let dt_2 = self.dt * into_scalar(0.5);
let dt_6 = self.dt / into_scalar(6.0);
buf.x.zip_mut_with(x, |buf, x| *buf = *x);
// k1
let mut k1 = self.f.rhs(x);
buf.k1.zip_mut_with(k1, |buf, k1| *buf = *k1);
Zip::from(&mut *k1)
.and(&buf.x)
.apply(|k1, &x| { *k1 = k1.mul_real(dt_2) + x; });
// k2
let mut k2 = self.f.rhs(k1);
buf.k2.zip_mut_with(k2, |buf, k| *buf = *k);
Zip::from(&mut *k2)
.and(&buf.x)
.apply(|k2, &x| { *k2 = x + k2.mul_real(dt_2); });
// k3
let mut k3 = self.f.rhs(k2);
buf.k3.zip_mut_with(k3, |buf, k| *buf = *k);
Zip::from(&mut *k3)
.and(&buf.x)
.apply(|k3, &x| { *k3 = x + k3.mul_real(dt); });
let mut k4 = self.f.rhs(k3);
Zip::from(&mut *k4)
.and(&buf.x)
.and(&buf.k1)
.and(&buf.k2)
.and(&buf.k3)
.apply(|k4, &x, &k1, &k2, &k3| {
*k4 = x + (k1 + (k2 + k3).mul_real(into_scalar(2.0)) + *k4).mul_real(dt_6);
});
k4
}
}
20 changes: 19 additions & 1 deletion src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ pub trait TimeEvolutionBase<S, D>: ModelSize<D> + TimeStep
fn iterate<'a>(&self, &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D>;
}


pub trait TimeEvolution<A, D>
: TimeEvolutionBase<OwnedRepr<A>, D, Scalar = A, Time = A::Real>
+ TimeEvolutionBase<OwnedRcRepr<A>, D, Scalar = A, Time = A::Real>
Expand All @@ -54,3 +53,22 @@ pub trait TimeEvolution<A, D>
D: Dimension
{
}

/// Time-evolution operator with buffer
pub trait TimeEvolutionBufferedBase<S, D, Buffer>: ModelSize<D> + TimeStep
where S: DataMut,
D: Dimension
{
type Scalar: Scalar;
/// calculate next step
fn iterate_buf<'a>(&self, &'a mut ArrayBase<S, D>, &mut Buffer) -> &'a mut ArrayBase<S, D>;
}

pub trait TimeEvolutionBuffered<A, D, Buffer>
: TimeEvolutionBufferedBase<OwnedRepr<A>, D, Buffer, Scalar = A, Time = A::Real>
+ TimeEvolutionBufferedBase<OwnedRcRepr<A>, D, Buffer, Scalar = A, Time = A::Real>
+ for<'a> TimeEvolutionBufferedBase<ViewRepr<&'a mut A>, D, Buffer, Scalar = A, Time = A::Real>
where A: Scalar,
D: Dimension
{
}

0 comments on commit 38113b1

Please sign in to comment.