-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add
augurs-dtw
crate with dynamic time warping implementation (…
…#98) Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
- Loading branch information
1 parent
edf6606
commit d7a3dbb
Showing
23 changed files
with
1,654 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
use std::{fmt, ops::Index}; | ||
|
||
/// An error that can occur when creating a `DistanceMatrix`. | ||
#[derive(Debug)] | ||
pub enum DistanceMatrixError { | ||
/// The input matrix is not square. | ||
InvalidDistanceMatrix, | ||
} | ||
|
||
impl fmt::Display for DistanceMatrixError { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
f.write_str("invalid distance matrix") | ||
} | ||
} | ||
|
||
impl std::error::Error for DistanceMatrixError {} | ||
|
||
/// A matrix representing the distances between pairs of items. | ||
#[derive(Debug, Clone)] | ||
pub struct DistanceMatrix { | ||
matrix: Vec<Vec<f64>>, | ||
} | ||
|
||
impl DistanceMatrix { | ||
/// Create a new `DistanceMatrix` from a square matrix. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns an error if the input matrix is not square. | ||
pub fn try_from_square(matrix: Vec<Vec<f64>>) -> Result<Self, DistanceMatrixError> { | ||
if matrix.iter().all(|x| x.len() == matrix.len()) { | ||
Ok(Self { matrix }) | ||
} else { | ||
Err(DistanceMatrixError::InvalidDistanceMatrix) | ||
} | ||
} | ||
|
||
/// Consumes the `DistanceMatrix` and returns the inner matrix. | ||
pub fn into_inner(self) -> Vec<Vec<f64>> { | ||
self.matrix | ||
} | ||
|
||
/// Returns an iterator over the rows of the matrix. | ||
pub fn iter(&self) -> DistanceMatrixIter<'_> { | ||
DistanceMatrixIter { | ||
iter: self.matrix.iter(), | ||
} | ||
} | ||
|
||
/// Returns the shape of the matrix. | ||
/// | ||
/// The first element is the number of rows and the second element | ||
/// is the number of columns. | ||
/// | ||
/// The matrix is square, so the number of rows is equal to the number of columns | ||
/// and the number of input series. | ||
pub fn shape(&self) -> (usize, usize) { | ||
(self.matrix.len(), self.matrix.len()) | ||
} | ||
} | ||
|
||
impl Index<usize> for DistanceMatrix { | ||
type Output = [f64]; | ||
fn index(&self, index: usize) -> &Self::Output { | ||
&self.matrix[index] | ||
} | ||
} | ||
|
||
impl Index<(usize, usize)> for DistanceMatrix { | ||
type Output = f64; | ||
fn index(&self, (i, j): (usize, usize)) -> &Self::Output { | ||
&self.matrix[i][j] | ||
} | ||
} | ||
|
||
impl IntoIterator for DistanceMatrix { | ||
type Item = Vec<f64>; | ||
type IntoIter = std::vec::IntoIter<Self::Item>; | ||
fn into_iter(self) -> Self::IntoIter { | ||
self.matrix.into_iter() | ||
} | ||
} | ||
|
||
/// An iterator over the rows of a `DistanceMatrix`. | ||
#[derive(Debug)] | ||
pub struct DistanceMatrixIter<'a> { | ||
iter: std::slice::Iter<'a, Vec<f64>>, | ||
} | ||
|
||
impl<'a> Iterator for DistanceMatrixIter<'a> { | ||
type Item = &'a Vec<f64>; | ||
fn next(&mut self) -> Option<Self::Item> { | ||
self.iter.next() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Changelog | ||
All notable changes to this project will be documented in this file. | ||
|
||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), | ||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). | ||
|
||
## [Unreleased] | ||
|
||
### Other | ||
- Add `augurs-dtw` crate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
[package] | ||
name = "augurs-dtw" | ||
license.workspace = true | ||
authors.workspace = true | ||
documentation.workspace = true | ||
repository.workspace = true | ||
version.workspace = true | ||
edition.workspace = true | ||
keywords.workspace = true | ||
description = "Dynamic Time Warping (DTW) algorithm for Rust" | ||
|
||
[dependencies] | ||
augurs-core.workspace = true | ||
rayon = { version = "1.10.0", optional = true } | ||
|
||
[features] | ||
parallel = ["dep:rayon"] | ||
|
||
[dev-dependencies] | ||
criterion.workspace = true | ||
itertools.workspace = true | ||
|
||
[lib] | ||
bench = false | ||
|
||
[[bench]] | ||
name = "dtw" | ||
harness = false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../LICENSE-APACHE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../LICENSE-MIT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Dynamic Time Warping (DTW) | ||
|
||
Implementation of the dynamic time warping (DTW) algorithm. | ||
|
||
DTW can be used to compare two sequences that may vary in time or speed. | ||
This implementation has built-in support for both Euclidean and Manhattan distance, | ||
and can be extended to support other distance functions by implementing the [`Distance`] | ||
trait and using the [`Dtw::new`] constructor. | ||
|
||
## Features | ||
|
||
- [x] DTW distance between two sequences | ||
- [x] optimized scalar implementation influenced by the [UCR Suite][ucr-suite] | ||
- [ ] SIMD optimized implementation | ||
- [ ] Z-normalization | ||
- [x] distance matrix calculations between N sequences | ||
- [x] parallelized distance matrix calculations | ||
- [ ] early stopping using `LB_Kim` (semi-implemented) | ||
- [ ] early stopping using `LB_Keogh` (semi-implemented) | ||
- [x] early stopping using the Euclidean upper bound | ||
|
||
Pull requests for missing features would be very welcome. | ||
|
||
## Usage | ||
|
||
``` | ||
use augurs_dtw::Dtw; | ||
let a = &[0.0, 1.0, 2.0]; | ||
let b = &[3.0, 4.0, 5.0]; | ||
let dist = Dtw::euclidean().distance(a, b); | ||
assert_eq!(dist, 5.0990195135927845); | ||
``` | ||
|
||
## Credits | ||
|
||
The algorithm is based on the code from the [UCR Suite][ucr-suite]. Benchmarks show similar | ||
or faster timings compared to [`dtaidistance`]'s C implementation, but note that `dtaidistance` is much more | ||
full featured! | ||
|
||
[ucr-suite]: https://www.cs.ucr.edu/~eamonn/UCRsuite.html | ||
[`dtaidistance`]: https://dtaidistance.readthedocs.io/ | ||
|
||
## License | ||
|
||
Dual-licensed to be compatible with the Rust project. | ||
Licensed under the Apache License, Version 2.0 `<http://www.apache.org/licenses/LICENSE-2.0>` or the MIT license `<http://opensource.org/licenses/MIT>`, at your option. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; | ||
use itertools::Itertools; | ||
|
||
use augurs_dtw::Dtw; | ||
|
||
fn examples() -> Vec<Vec<f64>> { | ||
let raw = include_str!("../data/series.csv"); | ||
let n_columns = raw.lines().next().unwrap().split(',').count(); | ||
let n_rows = raw.lines().count(); | ||
let mut examples = vec![Vec::with_capacity(n_rows); n_columns]; | ||
for line in raw.lines() { | ||
for (i, value) in line.split(',').enumerate() { | ||
let value: f64 = value.parse().unwrap(); | ||
if !value.is_nan() { | ||
examples[i].push(value); | ||
} | ||
} | ||
} | ||
examples | ||
} | ||
|
||
fn distance_euclidean(c: &mut Criterion) { | ||
let mut group = c.benchmark_group("distance_euclidean"); | ||
let examples = examples(); | ||
let (s, t) = (&examples[0], &examples[1]); | ||
let windows = [None, Some(2), Some(5), Some(10), Some(20), Some(50)]; | ||
for window in windows { | ||
group.bench_with_input( | ||
BenchmarkId::from_parameter(format!("{:?}", window)), | ||
&(s, t), | ||
|b, (s, t)| { | ||
b.iter(|| { | ||
let mut dtw = Dtw::euclidean(); | ||
if let Some(window) = window { | ||
dtw = dtw.with_window(window); | ||
} | ||
dtw.distance(s, t) | ||
}); | ||
}, | ||
); | ||
} | ||
} | ||
|
||
fn distance_matrix_euclidean(c: &mut Criterion) { | ||
let mut group = c.benchmark_group("distance_matrix_euclidean"); | ||
let examples = examples(); | ||
let examples = examples.iter().map(|v| v.as_slice()).collect::<Vec<_>>(); | ||
let windows = [Some(2), Some(10)]; | ||
let parallelize = [false, true]; | ||
for (window, parallelize) in windows.into_iter().cartesian_product(parallelize) { | ||
group.bench_with_input( | ||
BenchmarkId::from_parameter(format!( | ||
"window: {:?}, parallelize: {:?}", | ||
window, parallelize | ||
)), | ||
&examples, | ||
|b, examples| { | ||
b.iter(|| { | ||
let mut dtw = Dtw::euclidean().parallelize(parallelize); | ||
if let Some(window) = window { | ||
dtw = dtw.with_window(window).with_max_distance(window as f64); | ||
} | ||
dtw.distance_matrix(examples) | ||
}); | ||
}, | ||
); | ||
} | ||
} | ||
|
||
criterion_group!(benches, distance_euclidean, distance_matrix_euclidean); | ||
criterion_main!(benches); |
Oops, something went wrong.