Skip to content

Commit

Permalink
Better inner product call
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Jan 12, 2025
1 parent b8102bb commit 87ec108
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 20 deletions.
6 changes: 3 additions & 3 deletions src/operator/interface/array_vector_space.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl<Item: RlstScalar> LinearSpace for ArrayVectorSpace<Item> {
}

impl<Item: RlstScalar> InnerProductSpace for ArrayVectorSpace<Item> {
fn inner(&self, x: &Self::E, other: &Self::E) -> Self::F {
fn inner_product(&self, x: &Self::E, other: &Self::E) -> Self::F {
x.view().inner(other.view())
}
}
Expand All @@ -92,8 +92,8 @@ impl<Item: RlstScalar> Element for ArrayVectorSpaceElement<Item> {
where
Self: 'b;

fn space(&self) -> &Self::Space {
&self.space
fn space(&self) -> Rc<Self::Space> {
self.space.clone()
}

fn view(&self) -> Self::View<'_> {
Expand Down
6 changes: 3 additions & 3 deletions src/operator/interface/distributed_array_vector_space.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl<'a, C: Communicator, Item: RlstScalar + Equivalence> LinearSpace
impl<C: Communicator, Item: RlstScalar + Equivalence> InnerProductSpace
for DistributedArrayVectorSpace<'_, C, Item>
{
fn inner(&self, x: &Self::E, other: &Self::E) -> Self::F {
fn inner_product(&self, x: &Self::E, other: &Self::E) -> Self::F {
x.view().inner(other.view())
}
}
Expand All @@ -98,8 +98,8 @@ impl<'a, C: Communicator, Item: RlstScalar + Equivalence> Element
where
Self: 'b;

fn space(&self) -> &Self::Space {
&self.space
fn space(&self) -> Rc<Self::Space> {
self.space.clone()
}

fn view(&self) -> Self::View<'_> {
Expand Down
8 changes: 5 additions & 3 deletions src/operator/operations/conjugate_gradients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ impl<'a, Space: InnerProductSpace, Op: AsApply<Domain = Space, Range = Space>>
// If I write `let rhs_norm = self.rhs.norm()` the compiler thinks that `self.rhs` is a space and
// not an element.
let rhs_norm = <Space as NormedSpace>::norm(&self.operator.range(), self.rhs);
let mut res_inner = <Space as InnerProductSpace>::inner(&self.operator.range(), &res, &res);
let mut res_inner =
<Space as InnerProductSpace>::inner_product(&self.operator.range(), &res, &res);
let mut res_norm = res_inner.abs().sqrt();
let mut rel_res = res_norm / rhs_norm;

Expand All @@ -98,7 +99,7 @@ impl<'a, Space: InnerProductSpace, Op: AsApply<Domain = Space, Range = Space>>
}

for it_count in 0..self.max_iter {
let p_conj_inner = <Space as InnerProductSpace>::inner(
let p_conj_inner = <Space as InnerProductSpace>::inner_product(
&self.operator.range(),
&self.operator.apply(&p),
&p,
Expand All @@ -116,7 +117,8 @@ impl<'a, Space: InnerProductSpace, Op: AsApply<Domain = Space, Range = Space>>
callable(&self.x, &res);
}
let res_inner_previous = res_inner;
res_inner = <Space as InnerProductSpace>::inner(&self.operator.range(), &res, &res);
res_inner =
<Space as InnerProductSpace>::inner_product(&self.operator.range(), &res, &res);
res_norm = res_inner.abs().sqrt();
rel_res = res_norm / rhs_norm;
if res_norm < self.tol {
Expand Down
2 changes: 1 addition & 1 deletion src/operator/operations/modified_gram_schmidt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl ModifiedGramSchmidt {
for elem_index in 0..nelements {
let mut elem = (frame.get(elem_index).unwrap()).clone();
for (other_index, other_elem) in frame.iter().take(elem_index).enumerate() {
let inner = space.inner(&elem, other_elem);
let inner = space.inner_product(&elem, other_elem);
*r_mat.get_mut([other_index, elem_index]).unwrap() = inner;
elem.axpy_inplace(-inner, other_elem);
}
Expand Down
14 changes: 7 additions & 7 deletions src/operator/space/element.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//! Elements of linear spaces
use std::rc::Rc;

use crate::dense::types::RlstScalar;
use num::One;

use super::LinearSpace;
use super::{InnerProductSpace, LinearSpace, NormedSpace};

/// An Element of a linear spaces.
pub trait Element: Clone {
Expand All @@ -20,7 +22,7 @@ pub trait Element: Clone {
Self: 'b;

/// Return the associated function space.
fn space(&self) -> &Self::Space;
fn space(&self) -> Rc<Self::Space>;

/// Get a view onto the element.
fn view(&self) -> Self::View<'_>;
Expand Down Expand Up @@ -95,12 +97,11 @@ pub trait Element: Clone {
/// Comppute the inner product with another vector
///
/// Only implemented for elements of inner product spaces.
fn inner(&self, other: &Self) -> Self::F
fn inner_product(&self, other: &Self) -> Self::F
where
Self::Space: super::InnerProductSpace,
{
// Weird way of writing it because rust-analyzer gets confused which inner to choose.
<Self::Space as super::InnerProductSpace>::inner(self.space(), self, other)
self.space().inner_product(self, other)
}

/// Compute the norm of a vector
Expand All @@ -110,8 +111,7 @@ pub trait Element: Clone {
where
Self::Space: super::NormedSpace,
{
// Rust-Analyzer is otherwise confused what `norm` method to choose.
<Self::Space as super::NormedSpace>::norm(self.space(), self)
self.space().norm(self)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/operator/space/inner_product_space.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ use super::LinearSpace;
/// Inner product space
pub trait InnerProductSpace: LinearSpace {
/// Inner product
fn inner(&self, x: &Self::E, other: &Self::E) -> Self::F;
fn inner_product(&self, x: &Self::E, other: &Self::E) -> Self::F;
}
2 changes: 1 addition & 1 deletion src/operator/space/normed_space.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub trait NormedSpace: LinearSpace {

impl<S: InnerProductSpace> NormedSpace for S {
fn norm(&self, x: &ElementType<Self>) -> <Self::F as RlstScalar>::Real {
let abs_square = self.inner(x, x).abs();
let abs_square = self.inner_product(x, x).abs();
abs_square.sqrt()
}
}
2 changes: 1 addition & 1 deletion tests/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub fn test_gram_schmidt() {
// Check orthogonality
for index1 in 0..3 {
for index2 in 0..3 {
let inner = space.inner(frame.get(index1).unwrap(), frame.get(index2).unwrap());
let inner = space.inner_product(frame.get(index1).unwrap(), frame.get(index2).unwrap());
if index1 == index2 {
approx::assert_relative_eq!(inner, c64::one(), epsilon = 1E-12);
} else {
Expand Down

0 comments on commit 87ec108

Please sign in to comment.