Skip to content

Commit

Permalink
Fix quat_to_mat3. (#172)
Browse files Browse the repository at this point in the history
* Fix quat_to_mat3.

* Add python bindings.
  • Loading branch information
benjaminrwilson authored Apr 28, 2023
1 parent 98eb457 commit df8146f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
15 changes: 14 additions & 1 deletion rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ pub mod se3;
pub mod so3;

use data_loader::{DataLoader, Sweep};
use ndarray::Dim;
use ndarray::{Dim, Ix1, Ix2};
use numpy::PyReadonlyArray;
use numpy::{IntoPyArray, PyArray};
use pyo3::prelude::*;

use numpy::PyReadonlyArray2;
use so3::quat_to_mat3;

use crate::ops::voxelize;

Expand Down Expand Up @@ -51,11 +53,22 @@ fn py_voxelize<'py>(
)
}

#[pyfunction]
#[pyo3(name = "quat_to_mat3")]
#[allow(clippy::type_complexity)]
fn py_quat_to_mat3<'py>(
py: Python<'py>,
quat_wxyz: PyReadonlyArray<f32, Ix1>,
) -> &'py PyArray<f32, Ix2> {
quat_to_mat3(&quat_wxyz.as_array().view()).into_pyarray(py)
}

/// A Python module implemented in Rust.
#[pymodule]
fn _r(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<DataLoader>()?;
m.add_class::<Sweep>()?;
m.add_function(wrap_pyfunction!(py_quat_to_mat3, m)?)?;
m.add_function(wrap_pyfunction!(py_voxelize, m)?)?;
Ok(())
}
31 changes: 10 additions & 21 deletions rust/src/so3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,17 @@ pub fn quat_to_mat3(quat_wxyz: &ArrayView1<f32>) -> Array2<f32> {
let y = quat_wxyz[2];
let z = quat_wxyz[3];

let tx = 2. * x;
let ty = 2. * y;
let tz = 2. * z;
let twx = tx * w;
let twy = ty * w;
let twz = tz * w;
let txx = tx * x;
let txy = ty * x;
let txz = tz * x;
let tyy = ty * y;
let tyz = tz * y;
let tzz = tz * z;
let e_00 = 1. - 2. * y.powi(2) - 2. * z.powi(2);
let e_01: f32 = 2. * x * y - 2. * z * w;
let e_02: f32 = 2. * x * z + 2. * y * w;

let e_00 = 1. - (tyy + tzz);
let e_01 = txy - twz;
let e_02 = txy + twy;
let e_10 = txy + twz;
let e_11 = 1. - (txx + tzz);
let e_12 = tyz - twx;
let e_20 = txz - twy;
let e_21 = tyz + twx;
let e_22 = 1. - (txx + tyy);
let e_10 = 2. * x * y + 2. * z * w;
let e_11 = 1. - 2. * x.powi(2) - 2. * z.powi(2);
let e_12 = 2. * y * z - 2. * x * w;

let e_20 = 2. * x * z - 2. * y * w;
let e_21 = 2. * y * z + 2. * x * w;
let e_22 = 1. - 2. * x.powi(2) - 2. * y.powi(2);

// Safety: We will always have nine elements.
unsafe {
Expand Down

0 comments on commit df8146f

Please sign in to comment.