diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 49426e9b..084cfe87 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -14,11 +14,13 @@ pub mod se3; pub mod so3; use data_loader::{DataLoader, Sweep}; -use ndarray::Dim; +use ndarray::{Dim, Ix1, Ix2}; +use numpy::PyReadonlyArray; use numpy::{IntoPyArray, PyArray}; use pyo3::prelude::*; use numpy::PyReadonlyArray2; +use so3::quat_to_mat3; use crate::ops::voxelize; @@ -51,11 +53,22 @@ fn py_voxelize<'py>( ) } +#[pyfunction] +#[pyo3(name = "quat_to_mat3")] +#[allow(clippy::type_complexity)] +fn py_quat_to_mat3<'py>( + py: Python<'py>, + quat_wxyz: PyReadonlyArray, +) -> &'py PyArray { + quat_to_mat3(&quat_wxyz.as_array().view()).into_pyarray(py) +} + /// A Python module implemented in Rust. #[pymodule] fn _r(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_function(wrap_pyfunction!(py_quat_to_mat3, m)?)?; m.add_function(wrap_pyfunction!(py_voxelize, m)?)?; Ok(()) } diff --git a/rust/src/so3.rs b/rust/src/so3.rs index b964578d..5ae11546 100644 --- a/rust/src/so3.rs +++ b/rust/src/so3.rs @@ -11,28 +11,17 @@ pub fn quat_to_mat3(quat_wxyz: &ArrayView1) -> Array2 { let y = quat_wxyz[2]; let z = quat_wxyz[3]; - let tx = 2. * x; - let ty = 2. * y; - let tz = 2. * z; - let twx = tx * w; - let twy = ty * w; - let twz = tz * w; - let txx = tx * x; - let txy = ty * x; - let txz = tz * x; - let tyy = ty * y; - let tyz = tz * y; - let tzz = tz * z; + let e_00 = 1. - 2. * y.powi(2) - 2. * z.powi(2); + let e_01: f32 = 2. * x * y - 2. * z * w; + let e_02: f32 = 2. * x * z + 2. * y * w; - let e_00 = 1. - (tyy + tzz); - let e_01 = txy - twz; - let e_02 = txy + twy; - let e_10 = txy + twz; - let e_11 = 1. - (txx + tzz); - let e_12 = tyz - twx; - let e_20 = txz - twy; - let e_21 = tyz + twx; - let e_22 = 1. - (txx + tyy); + let e_10 = 2. * x * y + 2. * z * w; + let e_11 = 1. - 2. * x.powi(2) - 2. * z.powi(2); + let e_12 = 2. * y * z - 2. * x * w; + + let e_20 = 2. * x * z - 2. * y * w; + let e_21 = 2. * y * z + 2. * x * w; + let e_22 = 1. - 2. * x.powi(2) - 2. * y.powi(2); // Safety: We will always have nine elements. unsafe {