Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor improve for pmsim #120

Merged
merged 7 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 188 additions & 26 deletions russell_lab/src/vector/num_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::cmp;
use std::fmt::{self, Write};
use std::ops::{Index, IndexMut};
use std::ops::{Index, IndexMut, MulAssign};

/// Implements a vector with numeric components for linear algebra
///
Expand Down Expand Up @@ -90,15 +90,15 @@ use std::ops::{Index, IndexMut};
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
#[serde(bound(deserialize = "Vec<T>: Deserialize<'de>"))]
data: Vec<T>,
}

impl<T> NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
/// Creates a new (zeroed) vector
///
Expand Down Expand Up @@ -387,17 +387,17 @@ where

/// Returns the i-th component
///
/// # Panics
///
/// This function may panic if the index is out-of-bounds.
///
/// # Examples
///
/// ```
/// # use russell_lab::NumVector;
/// let u = NumVector::<f64>::from(&[1.0, 2.0]);
/// assert_eq!(u.get(1), 2.0);
/// ```
///
/// # Panics
///
/// This function may panic if the index is out-of-bounds.
#[inline]
pub fn get(&self, i: usize) -> T {
assert!(i < self.data.len());
Expand All @@ -406,6 +406,10 @@ where

/// Change the i-th component
///
/// # Panics
///
/// This function may panic if the index is out-of-bounds.
///
/// # Examples
///
/// ```
Expand All @@ -418,16 +422,116 @@ where
/// └ ┘";
/// assert_eq!(format!("{}", u), correct);
/// ```
///
/// # Panics
///
/// This function may panic if the index is out-of-bounds.
#[inline]
pub fn set(&mut self, i: usize, value: T) {
assert!(i < self.data.len());
self.data[i] = value;
}

/// Copy another vector into this one
///
/// # Panics
///
/// This function may panic if the other vector has a different length than this one
///
/// # Examples
///
/// ```
/// # use russell_lab::NumVector;
/// let mut u = NumVector::<f64>::from(&[1.0, 2.0]);
/// u.set_vector(&[-3.0, -4.0]);
/// let correct = "┌ ┐\n\
/// │ -3 │\n\
/// │ -4 │\n\
/// └ ┘";
/// assert_eq!(format!("{}", u), correct);
/// ```
pub fn set_vector(&mut self, other: &[T]) {
assert_eq!(other.len(), self.data.len());
self.data.copy_from_slice(other);
}

/// Splits this vector into another two vectors
///
/// **Requirements:** `u.len() + v.len() == self.len()`
///
/// This function is the opposite of [NumVector::join2()]
///
/// # Panics
///
/// This function may panic if the sum of the lengths of u and v are different that this vector's length
///
/// # Examples
///
/// ```
/// # use russell_lab::NumVector;
/// let w = NumVector::<f64>::from(&[1.0, 2.0, 3.0]);
/// let mut u = NumVector::<f64>::new(2);
/// let mut v = NumVector::<f64>::new(1);
///
/// w.split2(u.as_mut_data(), v.as_mut_data());
///
/// assert_eq!(u.as_data(), &[1.0, 2.0]);
/// assert_eq!(v.as_data(), &[3.0]);
/// ```
pub fn split2(&self, u: &mut [T], v: &mut [T]) {
assert_eq!(u.len() + v.len(), self.data.len());
u.copy_from_slice(&self.data[..u.len()]);
v.copy_from_slice(&self.data[u.len()..]);
}

/// Joins two vectors into this one
///
/// **Requirements:** `u.len() + v.len() == self.len()`
///
/// This function is the opposite of [NumVector::split2()]
///
/// # Panics
///
/// This function may panic if the sum of the lengths of u and v are different that this vector's length
///
/// # Examples
///
/// ```
/// # use russell_lab::NumVector;
/// let mut w = NumVector::<f64>::new(3);
/// let u = NumVector::<f64>::from(&[1.0, 2.0]);
/// let v = NumVector::<f64>::from(&[3.0]);
///
/// w.join2(u.as_data(), v.as_data());
///
/// assert_eq!(w.as_data(), &[1.0, 2.0, 3.0]);
/// ```
pub fn join2(&mut self, u: &[T], v: &[T]) {
assert_eq!(u.len() + v.len(), self.data.len());
(&mut self.data[..u.len()]).copy_from_slice(u);
(&mut self.data[u.len()..]).copy_from_slice(v);
}

/// Scales this vector
///
/// ```text
/// u := alpha * u
/// ```
///
/// # Examples
///
/// ```
/// # use russell_lab::NumVector;
/// let mut u = NumVector::<f64>::from(&[1.0, 2.0]);
/// u.scale(2.0);
/// let correct = "┌ ┐\n\
/// │ 2 │\n\
/// │ 4 │\n\
/// └ ┘";
/// assert_eq!(format!("{}", u), correct);
/// ```
pub fn scale(&mut self, alpha: T) {
for i in 0..self.data.len() {
self.data[i] *= alpha;
}
}

/// Applies a function over all components of this vector
///
/// ```text
Expand Down Expand Up @@ -524,7 +628,7 @@ where

impl<T> fmt::Display for NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize + fmt::Display,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize + fmt::Display,
{
/// Generates a string representation of the NumVector
///
Expand Down Expand Up @@ -584,6 +688,10 @@ where

/// Allows to access NumVector components using indices
///
/// # Panics
///
/// The index function may panic if the index is out-of-bounds.
///
/// # Examples
///
/// ```
Expand All @@ -593,13 +701,9 @@ where
/// assert_eq!(u[1], 1.2);
/// assert_eq!(u[2], 2.0);
/// ```
///
/// # Panics
///
/// The index function may panic if the index is out-of-bounds.
impl<T> Index<usize> for NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
type Output = T;
#[inline]
Expand All @@ -610,6 +714,10 @@ where

/// Allows to change NumVector components using indices
///
/// # Panics
///
/// The index function may panic if the index is out-of-bounds.
///
/// # Examples
///
/// ```
Expand All @@ -622,13 +730,9 @@ where
/// assert_eq!(u[1], 11.2);
/// assert_eq!(u[2], 22.0);
/// ```
///
/// # Panics
///
/// The index function may panic if the index is out-of-bounds.
impl<T> IndexMut<usize> for NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
#[inline]
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
Expand All @@ -649,7 +753,7 @@ where
/// ```
impl<T> IntoIterator for NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
type Item = T;
type IntoIter = std::vec::IntoIter<Self::Item>;
Expand All @@ -673,7 +777,7 @@ where
/// ```
impl<'a, T> IntoIterator for &'a NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
type Item = &'a T;
type IntoIter = std::slice::Iter<'a, T>;
Expand All @@ -698,7 +802,7 @@ where
/// ```
impl<'a, T> IntoIterator for &'a mut NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
type Item = &'a mut T;
type IntoIter = std::slice::IterMut<'a, T>;
Expand All @@ -710,7 +814,7 @@ where
/// Allows accessing NumVector as an Array1D
impl<'a, T: 'a> AsArray1D<'a, T> for NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
#[inline]
fn size(&self) -> usize {
Expand Down Expand Up @@ -887,6 +991,64 @@ mod tests {
assert_eq!(u.data, &[-1.0, -2.0]);
}

#[test]
#[should_panic]
fn set_vector_panics_on_wrong_len() {
let mut u = NumVector::<f64>::from(&[1.0, 2.0]);
u.set_vector(&[8.0, 9.0, 10.0]);
}

#[test]
fn set_vector_works() {
let mut u = NumVector::<f64>::from(&[1.0, 2.0]);
u.set_vector(&[8.0, 9.0]);
assert_eq!(u.data, &[8.0, 9.0]);
}

#[test]
#[should_panic]
fn split2_panics_on_wrong_lengths() {
let w = NumVector::<f64>::from(&[1.0, 2.0, 3.0]);
let mut u = NumVector::<f64>::new(2);
let mut v = NumVector::<f64>::new(2); // WRONG length
w.split2(u.as_mut_data(), v.as_mut_data());
}

#[test]
fn split2_works() {
let w = NumVector::<f64>::from(&[4.0, 5.0, -6.0]);
let mut u = NumVector::<f64>::new(2);
let mut v = NumVector::<f64>::new(1);
w.split2(u.as_mut_data(), v.as_mut_data());
assert_eq!(u.as_data(), &[4.0, 5.0]);
assert_eq!(v.as_data(), &[-6.0]);
}

#[test]
#[should_panic]
fn join2_panics_on_wrong_lengths() {
let mut w = NumVector::<f64>::new(2); // WRONG length
let u = NumVector::<f64>::from(&[1.0, 2.0]);
let v = NumVector::<f64>::from(&[3.0]);
w.join2(u.as_data(), v.as_data());
}

#[test]
fn join2_works() {
let mut w = NumVector::<f64>::new(4);
let u = NumVector::<f64>::from(&[9.0, -1.0, 7.0]);
let v = NumVector::<f64>::from(&[8.0]);
w.join2(u.as_data(), v.as_data());
assert_eq!(w.as_data(), &[9.0, -1.0, 7.0, 8.0]);
}

#[test]
fn scale_works() {
let mut u = NumVector::<f64>::from(&[2.0, 4.0]);
u.scale(0.5);
assert_eq!(u.data, &[1.0, 2.0]);
}

#[test]
fn map_works() {
let mut u = NumVector::<f64>::from(&[-1.0, -2.0, -3.0]);
Expand Down
2 changes: 1 addition & 1 deletion russell_ode/src/ode_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl<'a, A> OdeSolver<'a, A> {
/// # Generics
///
/// * `A` -- generic argument to assist in the f(x,y) and Jacobian functions.
/// It may be simply [NoArgs] indicating that no arguments are needed.
/// It may be simply [crate::NoArgs] indicating that no arguments are needed.
pub fn new(params: Params, system: System<'a, A>) -> Result<Self, StrError>
where
A: 'a,
Expand Down
10 changes: 9 additions & 1 deletion russell_ode/src/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl<'a, A> Output<'a, A> {
if self.with_dense_output() {
if let Some(h_out) = self.dense_h_out {
// uniform spacing
let n = ((x1 - x0) / h_out) as usize + 1;
let n = usize::max(2, ((x1 - x0) / h_out) as usize + 1); // at least 2 (first and last) are required
if self.dense_x.len() != n {
self.dense_x.resize(n, 0.0);
}
Expand Down Expand Up @@ -808,6 +808,14 @@ mod tests {
assert_eq!(y0_out.len(), 4);
}

#[test]
fn initialize_with_dense_output_works_at_least_two_stations() {
let mut out = Output::<'_, NoArgs>::new();
out.set_dense_h_out(0.5).unwrap().set_dense_recording(&[0]);
out.initialize(0.99, 1.0, false).unwrap();
assert_eq!(out.dense_x.len(), 2);
}

#[test]
fn initialize_with_step_output_works() {
let mut out = Output::<'_, NoArgs>::new();
Expand Down
Loading