Skip to content

Commit

Permalink
Correctly serialize X (regressors) as nested sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k committed Oct 10, 2024
1 parent 032b508 commit 59b5547
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 3 deletions.
139 changes: 138 additions & 1 deletion crates/augurs-prophet/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ impl<'de> serde::Deserialize<'de> for TrendIndicator {

/// Data for the Prophet model.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[allow(non_snake_case)]
pub struct Data {
/// Number of time periods.
Expand All @@ -110,6 +109,13 @@ pub struct Data {
/// Indicator of multiplicative features, length k.
pub s_m: Vec<i32>,
/// Regressors, shape (n, k).
///
/// This is stored as a `Vec<f64>` rather than a nested `Vec<Vec<f64>>`
/// because passing such a struct by reference is tricky in Rust, since
/// it can't be dereferenced to a `&[&[f64]]` (which would be ideal).
///
/// However, when serialized to JSON, it is converted to a nested array
/// of arrays, which is what cmdstan expects.
pub X: Vec<f64>,
/// Scale on seasonality prior.
pub sigmas: Vec<PositiveFloat>,
Expand All @@ -118,6 +124,50 @@ pub struct Data {
pub tau: PositiveFloat,
}

/// Serializes [`Data`] field by field, in declaration order, reshaping `X`
/// from its flat in-memory form into a nested sequence of rows.
///
/// `X` is stored as a flat `Vec<f64>` but cmdstan expects an array of arrays,
/// so it is emitted as consecutive rows of `K` values each. Every other field
/// is serialized unchanged under its own name.
#[cfg(feature = "serde")]
impl serde::Serialize for Data {
    /// Writes this value to `serializer` as a struct named `Data` with 13 fields.
    ///
    /// # Errors
    ///
    /// Returns a serializer error if `K` is zero or negative, because `X`
    /// cannot be split into rows of that length. Any error reported by the
    /// underlying serializer is propagated unchanged.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use std::num::NonZeroUsize;

        use serde::ser::{Error as _, SerializeSeq, SerializeStruct};

        /// Serializes a flat slice of `f64`s as a sequence of sequences, with each
        /// inner sequence holding (at most) the number of values given by the
        /// second field.
        ///
        /// The row length is a `NonZeroUsize` so that `chunks` can never be
        /// called with a chunk size of zero, which would panic.
        struct XSerializer<'a>(&'a [f64], NonZeroUsize);

        impl serde::Serialize for XSerializer<'_> {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                let rows = self.0.chunks(self.1.get());
                // Take the length hint from the iterator itself so that it also
                // counts a trailing short row when the slice length is not an
                // exact multiple of the row length.
                let mut outer = serializer.serialize_seq(Some(rows.len()))?;
                for row in rows {
                    outer.serialize_element(row)?;
                }
                outer.end()
            }
        }

        // Validate `K` before anything is written: a zero or negative value
        // would otherwise cause a panic (or a silently wrapped chunk size)
        // part-way through serialization.
        let row_len = usize::try_from(self.K)
            .ok()
            .and_then(NonZeroUsize::new)
            .ok_or_else(|| {
                S::Error::custom(format!(
                    "cannot serialize X as rows of length K: K must be positive, got {}",
                    self.K
                ))
            })?;
        let x = XSerializer(&self.X, row_len);

        let mut s = serializer.serialize_struct("Data", 13)?;
        s.serialize_field("T", &self.T)?;
        s.serialize_field("y", &self.y)?;
        s.serialize_field("t", &self.t)?;
        s.serialize_field("cap", &self.cap)?;
        s.serialize_field("S", &self.S)?;
        s.serialize_field("t_change", &self.t_change)?;
        s.serialize_field("trend_indicator", &self.trend_indicator)?;
        s.serialize_field("K", &self.K)?;
        s.serialize_field("s_a", &self.s_a)?;
        s.serialize_field("s_m", &self.s_m)?;
        s.serialize_field("X", &x)?;
        s.serialize_field("sigmas", &self.sigmas)?;
        s.serialize_field("tau", &self.tau)?;
        s.end()
    }
}

/// The algorithm to use for optimization. One of: 'BFGS', 'LBFGS', 'Newton'.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Algorithm {
Expand Down Expand Up @@ -303,3 +353,90 @@ pub mod mock_optimizer {
}
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the pretty-printed JSON form of `Data`: fields appear in the
    /// order they are serialized and the flat `X` vector is split into rows
    /// of `K` values.
    #[test]
    fn serialize_data() {
        let data = Data {
            T: 3,
            y: vec![1.0, 2.0, 3.0],
            t: vec![0.0, 1.0, 2.0],
            // Six values with K = 2 must come out as three rows of two.
            X: vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0],
            sigmas: vec![
                1.0.try_into().unwrap(),
                2.0.try_into().unwrap(),
                3.0.try_into().unwrap(),
            ],
            tau: 1.0.try_into().unwrap(),
            K: 2,
            s_a: vec![1, 1, 1],
            s_m: vec![0, 0, 0],
            cap: vec![0.0, 0.0, 0.0],
            S: 2,
            t_change: vec![0.0, 0.0, 0.0],
            trend_indicator: TrendIndicator::Linear,
        };
        let serialized = serde_json::to_string_pretty(&data).unwrap();
        // `to_string_pretty` indents nested values by two spaces per level, so
        // the expected text below is indentation-sensitive and starts at
        // column zero.
        pretty_assertions::assert_eq!(
            serialized,
            r#"{
  "T": 3,
  "y": [
    1.0,
    2.0,
    3.0
  ],
  "t": [
    0.0,
    1.0,
    2.0
  ],
  "cap": [
    0.0,
    0.0,
    0.0
  ],
  "S": 2,
  "t_change": [
    0.0,
    0.0,
    0.0
  ],
  "trend_indicator": 0,
  "K": 2,
  "s_a": [
    1,
    1,
    1
  ],
  "s_m": [
    0,
    0,
    0
  ],
  "X": [
    [
      1.0,
      2.0
    ],
    [
      3.0,
      1.0
    ],
    [
      2.0,
      3.0
    ]
  ],
  "sigmas": [
    1.0,
    2.0,
    3.0
  ],
  "tau": 1.0
}"#
        );
    }
}
23 changes: 21 additions & 2 deletions examples/forecasting/examples/prophet_cmdstan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
//! $ cargo run --features download --bin download-stan-model
//! $ cargo run --example prophet_cmdstan
//! ```
use augurs::prophet::{cmdstan::CmdstanOptimizer, Prophet, TrainingData};
use std::collections::HashMap;

use augurs::prophet::{cmdstan::CmdstanOptimizer, Prophet, Regressor, TrainingData};

fn main() -> Result<(), Box<dyn std::error::Error>> {
let ds = vec![
Expand All @@ -19,13 +21,30 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let y = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0,
];
let data = TrainingData::new(ds, y.clone())?;
let data = TrainingData::new(ds, y.clone())?
.with_regressors(HashMap::from([
(
"foo".to_string(),
vec![
1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0,
],
),
(
"bar".to_string(),
vec![
4.0, 5.0, 6.0, 4.0, 5.0, 6.0, 4.0, 5.0, 6.0, 4.0, 5.0, 6.0, 4.0,
],
),
]))
.unwrap();

let cmdstan = CmdstanOptimizer::with_prophet_path("prophet_stan_model/prophet_model.bin")?;
// If you were using the embedded version of the cmdstan model, you'd use this:
// let cmdstan = CmdstanOptimizer::new_embedded();

let mut prophet = Prophet::new(Default::default(), cmdstan);
prophet.add_regressor("foo".to_string(), Regressor::additive());
prophet.add_regressor("bar".to_string(), Regressor::additive());

prophet.fit(data, Default::default())?;
let predictions = prophet.predict(None)?;
Expand Down

0 comments on commit 59b5547

Please sign in to comment.