diff --git a/Cargo.lock b/Cargo.lock index 1791b889..4c380fab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1429,7 +1429,7 @@ dependencies = [ [[package]] name = "nshare" version = "0.9.0" -source = "git+https://github.com/benjaminrwilson/nshare#e896e103dc6b6e449b760d1e6d639590a864a80b" +source = "git+https://github.com/benjaminrwilson/nshare?rev=e896e103dc6b6e449b760d1e6d639590a864a80b#e896e103dc6b6e449b760d1e6d639590a864a80b" dependencies = [ "image", "nalgebra", diff --git a/rust/src/bin/build_accumulated_sweeps.rs b/rust/src/bin/build_accumulated_sweeps.rs index f68ae1eb..c4488eb8 100644 --- a/rust/src/bin/build_accumulated_sweeps.rs +++ b/rust/src/bin/build_accumulated_sweeps.rs @@ -35,8 +35,7 @@ const MEMORY_MAPPED: bool = false; static DST_DATASET_NAME: Lazy = Lazy::new(|| format!("{DATASET_NAME}_{NUM_ACCUMULATED_SWEEPS}_sweep")); -static SRC_PREFIX: Lazy = - Lazy::new(|| ROOT_DIR.join(DATASET_NAME.clone()).join(DATASET_TYPE)); +static SRC_PREFIX: Lazy = Lazy::new(|| ROOT_DIR.join(DATASET_NAME).join(DATASET_TYPE)); static DST_PREFIX: Lazy = Lazy::new(|| ROOT_DIR.join(DST_DATASET_NAME.clone()).join(DATASET_TYPE)); diff --git a/rust/src/geometry/augmentations.rs b/rust/src/geometry/augmentations.rs new file mode 100644 index 00000000..30297422 --- /dev/null +++ b/rust/src/geometry/augmentations.rs @@ -0,0 +1,23 @@ +//! # augmentations +//! +//! Geometric augmentations. + +use std::f32::consts::PI; + +use ndarray::{Array, ArrayView, Ix2}; + +use crate::geometry::so3::{quat_to_yaw, yaw_to_quat}; + +/// Reflect pose across the x-axis. +pub fn reflect_pose_x(quat_wxyz: &ArrayView) -> Array { + let yaw_rad = quat_to_yaw(quat_wxyz); + let reflected_yaw_rad = -yaw_rad; + yaw_to_quat(&reflected_yaw_rad.view()) +} + +/// Reflect pose across the y-axis. +pub fn reflect_pose_y(quat_wxyz: &ArrayView) -> Array { + let yaw_rad = quat_to_yaw(quat_wxyz); + let reflected_yaw_rad = PI - yaw_rad; + yaw_to_quat(&reflected_yaw_rad.view()) +} diff --git a/rust/src/geometry/camera/pinhole_camera.rs b/rust/src/geometry/camera/pinhole_camera.rs index 74e5e089..f0e31325 100644 --- a/rust/src/geometry/camera/pinhole_camera.rs +++ b/rust/src/geometry/camera/pinhole_camera.rs @@ -6,7 +6,10 @@ use polars::{ prelude::{DataFrame, IntoLazy}, }; -use crate::{geometry::utils::cart_to_hom, io::read_feather_eager, se3::SE3, so3::quat_to_mat3}; +use crate::{ + geometry::se3::SE3, geometry::so3::quat_to_mat3, geometry::utils::cart_to_hom, + io::read_feather_eager, +}; /// Pinhole camera intrinsics. #[derive(Clone, Debug)] diff --git a/rust/src/geometry/mod.rs b/rust/src/geometry/mod.rs index d9627158..d6c864a2 100644 --- a/rust/src/geometry/mod.rs +++ b/rust/src/geometry/mod.rs @@ -2,6 +2,13 @@ //! //! Geometric operations for data processing. +/// Geometric augmentations. +pub mod augmentations; /// Camera models. pub mod camera; +/// Special Euclidean Group 3. +pub mod se3; +/// Special Orthogonal Group 3. +pub mod so3; +/// Geometric utility functions. pub mod utils; diff --git a/rust/src/se3.rs b/rust/src/geometry/se3.rs similarity index 100% rename from rust/src/se3.rs rename to rust/src/geometry/se3.rs diff --git a/rust/src/geometry/so3.rs b/rust/src/geometry/so3.rs new file mode 100644 index 00000000..c4f52b96 --- /dev/null +++ b/rust/src/geometry/so3.rs @@ -0,0 +1,73 @@ +//! # SO(3) +//! +//! Special Orthogonal Group 3 (SO(3)). + +use ndarray::{par_azip, Array, Array2, ArrayView, Ix1, Ix2}; + +/// Convert a quaternion in scalar-first format to a 3x3 rotation matrix. +pub fn quat_to_mat3(quat_wxyz: &ArrayView) -> Array { + let w = quat_wxyz[0]; + let x = quat_wxyz[1]; + let y = quat_wxyz[2]; + let z = quat_wxyz[3]; + + let e_00 = 1. - 2. * y.powi(2) - 2. * z.powi(2); + let e_01: f32 = 2. * x * y - 2. * z * w; + let e_02: f32 = 2. * x * z + 2. * y * w; + + let e_10 = 2. * x * y + 2. * z * w; + let e_11 = 1. - 2. * x.powi(2) - 2. * z.powi(2); + let e_12 = 2. * y * z - 2. * x * w; + + let e_20 = 2. * x * z - 2. * y * w; + let e_21 = 2. * y * z + 2. * x * w; + let e_22 = 1. - 2. * x.powi(2) - 2. * y.powi(2); + + // Safety: We will always have nine elements. + unsafe { + Array2::from_shape_vec_unchecked( + [3, 3], + vec![e_00, e_01, e_02, e_10, e_11, e_12, e_20, e_21, e_22], + ) + } +} + +/// Convert a scalar-first quaternion to yaw. +/// In the Argoverse 2 coordinate system, this is counter-clockwise rotation about the +z axis. +/// Parallelized for batch processing. +pub fn quat_to_yaw(quat_wxyz: &ArrayView) -> Array { + let num_quats = quat_wxyz.shape()[0]; + let mut yaws_rad = Array::::zeros((num_quats, 1)); + par_azip!((mut y in yaws_rad.outer_iter_mut(), q in quat_wxyz.outer_iter()) { + y[0] = _quat_to_yaw(&q); + }); + yaws_rad +} + +/// Convert a scalar-first quaternion to yaw. +/// In the Argoverse 2 coordinate system, this is counter-clockwise rotation about the +z axis. +pub fn _quat_to_yaw(quat_wxyz: &ArrayView) -> f32 { + let (qw, qx, qy, qz) = (quat_wxyz[0], quat_wxyz[1], quat_wxyz[2], quat_wxyz[3]); + let siny_cosp = 2. * (qw * qz + qx * qy); + let cosy_cosp = 1. - 2. * (qy * qy + qz * qz); + siny_cosp.atan2(cosy_cosp) +} + +/// Convert a scalar-first quaternion to yaw. +/// In the Argoverse 2 coordinate system, this is counter-clockwise rotation about the +z axis. +/// Parallelized for batch processing. +pub fn yaw_to_quat(yaw_rad: &ArrayView) -> Array { + let num_yaws = yaw_rad.shape()[0]; + let mut quat_wxyz = Array::::zeros((num_yaws, 4)); + par_azip!((mut q in quat_wxyz.outer_iter_mut(), y in yaw_rad.outer_iter()) { + q.assign(&_yaw_to_quat(y[0])); + }); + quat_wxyz +} + +/// Convert rotation about the z-axis to a scalar-first quaternion. +pub fn _yaw_to_quat(yaw_rad: f32) -> Array { + let qw = f32::cos(0.5 * yaw_rad); + let qz = f32::sin(0.5 * yaw_rad); + Array::::from_vec(vec![qw, 0.0, 0.0, qz]) +} diff --git a/rust/src/io.rs b/rust/src/io.rs index 9169abd3..d7a6672d 100644 --- a/rust/src/io.rs +++ b/rust/src/io.rs @@ -29,10 +29,10 @@ use std::fs::File; use std::path::PathBuf; use crate::constants::POSE_COLUMNS; -use crate::se3::SE3; +use crate::geometry::se3::SE3; use image::io::Reader as ImageReader; -use crate::so3::quat_to_mat3; +use crate::geometry::so3::quat_to_mat3; /// Read a feather file and load into a `polars` dataframe. pub fn read_feather_eager(path: &PathBuf, memory_mapped: bool) -> DataFrame { diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 07f67699..0e6b0f5c 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -11,18 +11,18 @@ pub mod geometry; pub mod io; pub mod ops; pub mod path; -pub mod se3; -pub mod so3; +pub mod share; pub mod structures; use data_loader::{DataLoader, Sweep}; +use geometry::augmentations::{reflect_pose_x, reflect_pose_y}; use ndarray::{Dim, Ix1, Ix2}; use numpy::PyReadonlyArray; use numpy::{IntoPyArray, PyArray}; use pyo3::prelude::*; +use geometry::so3::{quat_to_mat3, quat_to_yaw, yaw_to_quat}; use numpy::PyReadonlyArray2; -use so3::quat_to_mat3; use crate::ops::voxelize; @@ -65,12 +65,56 @@ fn py_quat_to_mat3<'py>( quat_to_mat3(&quat_wxyz.as_array().view()).into_pyarray(py) } +#[pyfunction] +#[pyo3(name = "quat_to_yaw")] +#[allow(clippy::type_complexity)] +fn py_quat_to_yaw<'py>( + py: Python<'py>, + quat_wxyz: PyReadonlyArray, +) -> &'py PyArray { + quat_to_yaw(&quat_wxyz.as_array().view()).into_pyarray(py) +} + +#[pyfunction] +#[pyo3(name = "yaw_to_quat")] +#[allow(clippy::type_complexity)] +fn py_yaw_to_quat<'py>( + py: Python<'py>, + quat_wxyz: PyReadonlyArray, +) -> &'py PyArray { + yaw_to_quat(&quat_wxyz.as_array().view()).into_pyarray(py) +} + +#[pyfunction] +#[pyo3(name = "reflect_pose_x")] +#[allow(clippy::type_complexity)] +fn py_reflect_pose_x<'py>( + py: Python<'py>, + quat_wxyz: PyReadonlyArray, +) -> &'py PyArray { + reflect_pose_x(&quat_wxyz.as_array().view()).into_pyarray(py) +} + +#[pyfunction] +#[pyo3(name = "reflect_pose_y")] +#[allow(clippy::type_complexity)] +fn py_reflect_pose_y<'py>( + py: Python<'py>, + quat_wxyz: PyReadonlyArray, +) -> &'py PyArray { + reflect_pose_y(&quat_wxyz.as_array().view()).into_pyarray(py) +} + /// A Python module implemented in Rust. #[pymodule] fn _r(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(py_quat_to_mat3, m)?)?; + m.add_function(wrap_pyfunction!(py_quat_to_yaw, m)?)?; + m.add_function(wrap_pyfunction!(py_reflect_pose_x, m)?)?; + m.add_function(wrap_pyfunction!(py_reflect_pose_y, m)?)?; m.add_function(wrap_pyfunction!(py_voxelize, m)?)?; + m.add_function(wrap_pyfunction!(py_yaw_to_quat, m)?)?; Ok(()) } diff --git a/rust/src/share.rs b/rust/src/share.rs new file mode 100644 index 00000000..db5cf66a --- /dev/null +++ b/rust/src/share.rs @@ -0,0 +1,24 @@ +//! # share +//! +//! Conversion methods between different libraries. + +use ndarray::{Array, Ix2}; +use polars::{prelude::NamedFrom, series::Series}; + +/// Convert the columns of an `ndarray::Array` into a vector of `polars::series::Series`. +pub fn ndarray_to_series_vec(arr: Array, column_names: Vec<&str>) -> Vec { + let num_dims = arr.shape()[1]; + if num_dims != column_names.len() { + panic!("Number of array columns and column names must match."); + } + + let mut series_vec = vec![]; + for (column, column_name) in arr.columns().into_iter().zip(column_names) { + let series = Series::new( + column_name, + column.as_standard_layout().to_owned().into_raw_vec(), + ); + series_vec.push(series); + } + series_vec +} diff --git a/rust/src/so3.rs b/rust/src/so3.rs deleted file mode 100644 index 5ae11546..00000000 --- a/rust/src/so3.rs +++ /dev/null @@ -1,33 +0,0 @@ -//! # SO(3) -//! -//! Special Orthogonal Group 3 (SO(3)). - -use ndarray::{Array2, ArrayView1}; - -/// Convert a quaternion in scalar-first format to a 3x3 rotation matrix. -pub fn quat_to_mat3(quat_wxyz: &ArrayView1) -> Array2 { - let w = quat_wxyz[0]; - let x = quat_wxyz[1]; - let y = quat_wxyz[2]; - let z = quat_wxyz[3]; - - let e_00 = 1. - 2. * y.powi(2) - 2. * z.powi(2); - let e_01: f32 = 2. * x * y - 2. * z * w; - let e_02: f32 = 2. * x * z + 2. * y * w; - - let e_10 = 2. * x * y + 2. * z * w; - let e_11 = 1. - 2. * x.powi(2) - 2. * z.powi(2); - let e_12 = 2. * y * z - 2. * x * w; - - let e_20 = 2. * x * z - 2. * y * w; - let e_21 = 2. * y * z + 2. * x * w; - let e_22 = 1. - 2. * x.powi(2) - 2. * y.powi(2); - - // Safety: We will always have nine elements. - unsafe { - Array2::from_shape_vec_unchecked( - [3, 3], - vec![e_00, e_01, e_02, e_10, e_11, e_12, e_20, e_21, e_22], - ) - } -}