feat: add cmdstan-based optimizer for augurs-prophet #121

Merged
merged 10 commits into from
Oct 10, 2024
6 changes: 5 additions & 1 deletion .github/workflows/rust.yml
@@ -34,8 +34,12 @@ jobs:
with:
bins: cargo-nextest,just

- name: Download Prophet Stan model
# Download the Prophet Stan model since an example requires it.
run: just download-prophet-stan-model

- name: Run cargo nextest
-run: just test
+run: just test-all
- name: Run doc tests
run: just doctest

1 change: 1 addition & 0 deletions .gitignore
@@ -2,3 +2,4 @@
/Cargo.lock
.bacon-locations
.vscode
prophet_stan_model
1 change: 1 addition & 0 deletions Cargo.toml
@@ -45,6 +45,7 @@ rand = "0.8.5"
roots = "0.0.8"
serde = { version = "1.0.166", features = ["derive"] }
statrs = "0.17.1"
serde_json = "1.0.128"
thiserror = "1.0.40"
tinyvec = "1.6.0"
tracing = "0.1.37"
2 changes: 1 addition & 1 deletion crates/augurs-outlier/Cargo.toml
@@ -26,7 +26,7 @@ tracing.workspace = true

[dev-dependencies]
augurs.workspace = true
-serde_json = "1.0.128"
+serde_json.workspace = true

[features]
parallel = ["rayon"]
2 changes: 2 additions & 0 deletions crates/augurs-prophet/.gitignore
@@ -0,0 +1,2 @@
/build
/prophet_stan_model
23 changes: 23 additions & 0 deletions crates/augurs-prophet/Cargo.toml
@@ -15,12 +15,35 @@ itertools.workspace = true
num-traits.workspace = true
rand.workspace = true
statrs.workspace = true
serde = { workspace = true, optional = true, features = ["derive"] }
serde_json = { workspace = true, optional = true }
tempfile = { version = "3.13.0", optional = true }
thiserror.workspace = true
ureq = { version = "2.10.1", optional = true }
zip = { version = "2.2.0", optional = true }

[dev-dependencies]
augurs.workspace = true
augurs-testing.workspace = true
chrono.workspace = true
pretty_assertions.workspace = true

[build-dependencies]
tempfile = { version = "3.13.0", optional = true }

[features]
bytemuck = ["dep:bytemuck"]
cmdstan = ["dep:tempfile", "dep:serde_json", "serde"]
compile-cmdstan = ["cmdstan", "dep:tempfile"]
download = ["dep:ureq", "dep:zip"]
# Ignore cmdstan compilation in the build script.
# This should only be used for developing the library, not by
# end users, or you may end up with a broken build where the
# Prophet model isn't available to be compiled into the binary.
internal-ignore-cmdstan-failure = []
serde = ["dep:serde"]

[[bin]]
name = "download-stan-model"
path = "src/bin/main.rs"
required-features = ["download"]
58 changes: 55 additions & 3 deletions crates/augurs-prophet/README.md
@@ -3,6 +3,46 @@
`augurs-prophet` contains an implementation of the [Prophet]
time series forecasting library.

## Example

First, download the Prophet Stan model using the included binary:

```sh
$ cargo install --bin download-stan-model --features download augurs-prophet
$ download-stan-model
```

```rust,no_run
use augurs::prophet::{cmdstan::CmdstanOptimizer, Prophet, TrainingData};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let ds = vec![
        1704067200, 1704871384, 1705675569, 1706479753, 1707283938, 1708088123, 1708892307,
        1709696492, 1710500676, 1711304861, 1712109046, 1712913230, 1713717415,
    ];
    let y = vec![
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0,
    ];
    let data = TrainingData::new(ds, y.clone())?;

    let cmdstan = CmdstanOptimizer::with_prophet_path("prophet_stan_model/prophet_model.bin")?;
    // If you were using the embedded version of the cmdstan model, you'd enable
    // the `compile-cmdstan` feature and use this:
    //
    // let cmdstan = CmdstanOptimizer::new_embedded();

    let mut prophet = Prophet::new(Default::default(), cmdstan);

    prophet.fit(data, Default::default())?;
    let predictions = prophet.predict(None)?;
    assert_eq!(predictions.yhat.point.len(), y.len());
    assert!(predictions.yhat.lower.is_some());
    assert!(predictions.yhat.upper.is_some());
    println!("Predicted values: {:#?}", predictions.yhat);
    Ok(())
}
```

This crate aims to be low-dependency to enable it to run in as
many places as possible. With that said, we need to talk about
optimizers…
@@ -15,9 +55,13 @@ inference as well as maximum likelihood estimation using optimizers such as L-BF
However, it is written in C++ and has non-trivial dependencies, which makes it
difficult to interface with from Rust (or, indeed, Python).

-`augurs-prophet` (similar to the Python library) abstracts optimization
-and sampling implementations using the `Optimizer` and `Sampler` traits.
-These are yet to be implemented, but I have a few ideas:
+Similar to the Python library, `augurs-prophet` abstracts MLE optimization
+using the `Optimizer` trait and (later) MCMC sampling using the `Sampler` trait.
+There is a single implementation of `Optimizer`, which uses
+`cmdstan` to run the optimization. See below and the `cmdstan` module
+for details.
+
+Further implementations are possible, with some ideas below.

### `cmdstan`

@@ -31,6 +75,14 @@ This works fine if you're operating in a desktop or server environment,
but poses issues when running in more esoteric environments such as
WebAssembly.

The `cmdstan` module of this crate contains an implementation of `Optimizer`
which runs the optimization using a compiled Stan program. See the `cmdstan`
module for more details on how to use it.

This requires the `cmdstan` feature. Also enable the `compile-cmdstan` feature
if you want to compile and embed the Stan model at build time.
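
As a rough illustration of what driving a compiled Stan program involves, the sketch below builds (but does not run) a command line in the style of the CmdStan CLI, using `optimize`, `algorithm=lbfgs`, `data file=...` and `output file=...`. It is not the code in the `cmdstan` module, which may work differently. The function name `build_optimize_command` and the data and output file names are hypothetical.

```rust
use std::path::Path;
use std::process::Command;

/// Build, without executing, a command asking a compiled Stan program to
/// run L-BFGS optimization. Hypothetical helper, not part of augurs-prophet.
fn build_optimize_command(model_exe: &Path, data_json: &Path, output_csv: &Path) -> Command {
    let mut cmd = Command::new(model_exe);
    cmd.arg("optimize")
        .arg("algorithm=lbfgs")
        // Input data is passed as a JSON file.
        .arg("data")
        .arg(format!("file={}", data_json.display()))
        // Optimized parameters are written to a CSV file.
        .arg("output")
        .arg(format!("file={}", output_csv.display()));
    cmd
}
```

A caller would then run the command with `cmd.output()` and parse the CSV file that the Stan program writes.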

### `libstan`

We could choose to write a `libstan` crate which uses [`cxx`][cxx] to
88 changes: 88 additions & 0 deletions crates/augurs-prophet/build.rs
@@ -0,0 +1,88 @@
/// Compile the Prophet model (in prophet.stan) to a binary,
/// using the Makefile in the cmdstan installation.
///
/// This requires:
/// - The `STAN_PATH` environment variable to be set to the
/// path to the Stan installation.
#[cfg(all(feature = "cmdstan", feature = "compile-cmdstan"))]
fn compile_model() -> Result<(), Box<dyn std::error::Error>> {
    use std::{fs, path::PathBuf, process::Command};
    use tempfile::TempDir;

    println!("cargo::rerun-if-changed=prophet.stan");
    println!("cargo::rerun-if-env-changed=STAN_PATH");

    let stan_path: PathBuf = std::env::var("STAN_PATH")
        .map_err(|_| "STAN_PATH not set")?
        .parse()
        .map_err(|_| "invalid STAN_PATH")?;
    let cmdstan_bin_path = stan_path.join("bin/cmdstan");
    let model_stan = include_bytes!("./prophet.stan");

    let build_dir = PathBuf::from(std::env::var("OUT_DIR")?);
    fs::create_dir_all(&build_dir).map_err(|_| "could not create build directory")?;

    // Write the Prophet Stan file to a named file in a temporary directory.
    let tmp_dir = TempDir::new()?;
    let prophet_stan_path = tmp_dir.path().join("prophet.stan");
    eprintln!("Writing Prophet model to {}", prophet_stan_path.display());
    fs::write(tmp_dir.path().join("prophet.stan"), model_stan)?;

    // The Stan Makefile expects to see the path to the final executable
    // file (without the .stan extension). It will build the executable
    // at this location.
    let tmp_exe_path = prophet_stan_path.with_extension("");

    // Execute the cmdstan make command pointing at the expected
    // prophet file.
    eprintln!("Compiling Prophet model to {}", tmp_exe_path.display());
    let mut cmd = Command::new("make");
    cmd.current_dir(cmdstan_bin_path).arg(&tmp_exe_path);
    eprintln!("Executing {:?}", cmd);
    let output = cmd.output()?;
    if !output.status.success() {
        return Err(format!("make failed: {}", String::from_utf8_lossy(&output.stderr)).into());
    }
    eprintln!("Successfully compiled Prophet model");

    // Copy the executable to the final location.
    let dest_exe_path = build_dir.join("prophet");
    std::fs::copy(tmp_exe_path, &dest_exe_path)?;
    eprintln!("Copied prophet exe to {}", dest_exe_path.display());

    // Copy libtbb to the final location.
    let libtbb_path = stan_path.join("lib/libtbb.so.12");
    let dest_libtbb_path = build_dir.join("libtbb.so.12");
    std::fs::copy(&libtbb_path, &dest_libtbb_path)?;
Comment on lines +54 to +56

⚠️ Potential issue

Handle Dynamic Library Version Flexibly

Hardcoding the library name libtbb.so.12 may lead to compatibility issues if the version changes.

Consider modifying the code to dynamically find the libtbb library:

  • Option 1: Use glob patterns to match any version of the shared library.

    use glob::glob;
    
    let libtbb_pattern = stan_path.join("lib/libtbb.so.*");
    let libtbb_path = glob(&libtbb_pattern.to_string_lossy())?
        .filter_map(Result::ok)
        .next()
        .ok_or("libtbb.so not found")?;
  • Option 2: Allow the library name to be configurable or document the required version explicitly.

This ensures that your build script adapts to different versions of the libtbb library.
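
If adding the `glob` crate as a build dependency is undesirable, the same idea can be sketched with the standard library alone. This is only an illustration of Option 1 and not code from this PR. The helper name `find_libtbb` is hypothetical, and it assumes any file whose name starts with `libtbb.so` is acceptable.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Return a file in `lib_dir` whose name starts with `libtbb.so`, if any.
/// Candidates are sorted so the result is deterministic when several match.
/// Hypothetical helper, not part of the build script in this PR.
fn find_libtbb(lib_dir: &Path) -> io::Result<Option<PathBuf>> {
    let mut candidates: Vec<PathBuf> = fs::read_dir(lib_dir)?
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|path| {
            path.file_name()
                .and_then(|name| name.to_str())
                .map_or(false, |name| name.starts_with("libtbb.so"))
        })
        .collect();
    candidates.sort();
    Ok(candidates.into_iter().next())
}
```

With this approach the destination file name would also need to be derived from the discovered path rather than hardcoded, otherwise the version mismatch just moves to the copy step.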

    eprintln!(
        "Copied libtbb.so from {} to {}",
        libtbb_path.display(),
        dest_libtbb_path.display(),
    );

    Ok(())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let _result = Ok::<(), &'static str>(());
    #[cfg(all(feature = "cmdstan", feature = "compile-cmdstan"))]
    let _result = compile_model();
    // This is a complete hack but lets us get away with still using
    // the `--all-features` flag of Cargo without everything failing
    // if there isn't a Stan installation.
    // Basically, if we have this feature enabled, skip any failures in
    // the build process and just create some empty files.
    // This will cause things to fail at runtime if there isn't a Stan
    // installation, but that's okay because no-one should ever use this
    // feature.
    #[cfg(feature = "internal-ignore-cmdstan-failure")]
    if _result.is_err() {
        let out_dir = std::path::PathBuf::from(std::env::var("OUT_DIR")?);
        std::fs::create_dir_all(&out_dir)?;
        std::fs::File::create(out_dir.join("prophet"))?;
        std::fs::File::create(out_dir.join("libtbb.so.12"))?;
    }
    #[cfg(not(feature = "internal-ignore-cmdstan-failure"))]
    _result?;
    Ok(())
}
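
The fallback logic in `main` can be easier to follow when the feature flag is replaced by a plain boolean. The sketch below is a hypothetical restatement for illustration only, not a proposed change to `build.rs`. The name `finish_build` and the `String` error type are assumptions; the placeholder file names match those used above.

```rust
use std::fs;
use std::path::Path;

/// If compilation failed and failures are being ignored, create empty
/// placeholder artifacts in `out_dir`; otherwise pass the result through.
/// Hypothetical helper mirroring the cfg-gated logic in `main`.
fn finish_build(
    result: Result<(), String>,
    ignore_failure: bool,
    out_dir: &Path,
) -> Result<(), String> {
    match result {
        Err(_) if ignore_failure => {
            fs::create_dir_all(out_dir).map_err(|e| e.to_string())?;
            // Empty files keep later steps that expect these paths from
            // failing at build time; they will not work at runtime.
            for name in ["prophet", "libtbb.so.12"] {
                fs::File::create(out_dir.join(name)).map_err(|e| e.to_string())?;
            }
            Ok(())
        }
        other => other,
    }
}
```

Written this way, both branches can be exercised in a unit test without toggling Cargo features.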
Empty file.
144 changes: 144 additions & 0 deletions crates/augurs-prophet/prophet.stan
@@ -0,0 +1,144 @@
// Copyright (c) Facebook, Inc. and its affiliates.

// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

functions {
  matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
    // Assumes t and t_change are sorted.
    matrix[T, S] A;
    row_vector[S] a_row;
    int cp_idx;

    // Start with an empty matrix.
    A = rep_matrix(0, T, S);
    a_row = rep_row_vector(0, S);
    cp_idx = 1;

    // Fill in each row of A.
    for (i in 1:T) {
      while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) {
        a_row[cp_idx] = 1;
        cp_idx = cp_idx + 1;
      }
      A[i] = a_row;
    }
    return A;
  }

  // Logistic trend functions

  vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) {
    vector[S] gamma; // adjusted offsets, for piecewise continuity
    vector[S + 1] k_s; // actual rate in each segment
    real m_pr;

    // Compute the rate in each segment
    k_s = append_row(k, k + cumulative_sum(delta));

    // Piecewise offsets
    m_pr = m; // The offset in the previous segment
    for (i in 1:S) {
      gamma[i] = (t_change[i] - m_pr) * (1 - k_s[i] / k_s[i + 1]);
      m_pr = m_pr + gamma[i]; // update for the next segment
    }
    return gamma;
  }

  vector logistic_trend(
    real k,
    real m,
    vector delta,
    vector t,
    vector cap,
    matrix A,
    vector t_change,
    int S
  ) {
    vector[S] gamma;

    gamma = logistic_gamma(k, m, delta, t_change, S);
    return cap .* inv_logit((k + A * delta) .* (t - (m + A * gamma)));
  }

  // Linear trend function

  vector linear_trend(
    real k,
    real m,
    vector delta,
    vector t,
    matrix A,
    vector t_change
  ) {
    return (k + A * delta) .* t + (m + A * (-t_change .* delta));
  }

  // Flat trend function

  vector flat_trend(
    real m,
    int T
  ) {
    return rep_vector(m, T);
  }
}

data {
  int T; // Number of time periods
  int<lower=1> K; // Number of regressors
  vector[T] t; // Time
  vector[T] cap; // Capacities for logistic trend
  vector[T] y; // Time series
  int S; // Number of changepoints
  vector[S] t_change; // Times of trend changepoints
  matrix[T,K] X; // Regressors
  vector[K] sigmas; // Scale on seasonality prior
  real<lower=0> tau; // Scale on changepoints prior
  int trend_indicator; // 0 for linear, 1 for logistic, 2 for flat
  vector[K] s_a; // Indicator of additive features
  vector[K] s_m; // Indicator of multiplicative features
}

transformed data {
  matrix[T, S] A = get_changepoint_matrix(t, t_change, T, S);
  matrix[T, K] X_sa = X .* rep_matrix(s_a', T);
  matrix[T, K] X_sm = X .* rep_matrix(s_m', T);
}

parameters {
  real k; // Base trend growth rate
  real m; // Trend offset
  vector[S] delta; // Trend rate adjustments
  real<lower=0> sigma_obs; // Observation noise
  vector[K] beta; // Regressor coefficients
}

transformed parameters {
  vector[T] trend;
  if (trend_indicator == 0) {
    trend = linear_trend(k, m, delta, t, A, t_change);
  } else if (trend_indicator == 1) {
    trend = logistic_trend(k, m, delta, t, cap, A, t_change, S);
  } else if (trend_indicator == 2) {
    trend = flat_trend(m, T);
  }
}

model {
  // Priors
  k ~ normal(0, 5);
  m ~ normal(0, 5);
  delta ~ double_exponential(0, tau);
  sigma_obs ~ normal(0, 0.5);
  beta ~ normal(0, sigmas);

  // Likelihood
  y ~ normal_id_glm(
    X_sa,
    trend .* (1 + X_sm * beta),
    beta,
    sigma_obs
  );
}
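
To make the piecewise-linear trend concrete, here is a small Rust port of `get_changepoint_matrix` and `linear_trend` from the Stan model above. It is an illustrative sketch using plain `Vec<f64>` values and is not code from `augurs-prophet`. The function names mirror the Stan functions, and everything else (signatures, data layout) is an assumption.

```rust
/// Row i, column j is 1.0 once t[i] has reached changepoint j, else 0.0.
/// Like the Stan version, this assumes `t` and `t_change` are sorted.
fn changepoint_matrix(t: &[f64], t_change: &[f64]) -> Vec<Vec<f64>> {
    let mut a_row = vec![0.0; t_change.len()];
    let mut cp_idx = 0;
    t.iter()
        .map(|&ti| {
            // Switch on every changepoint that has been passed by time ti.
            while cp_idx < t_change.len() && ti >= t_change[cp_idx] {
                a_row[cp_idx] = 1.0;
                cp_idx += 1;
            }
            a_row.clone()
        })
        .collect()
}

/// Computes (k + A * delta) .* t + (m + A * (-t_change .* delta)).
fn linear_trend(k: f64, m: f64, delta: &[f64], t: &[f64], t_change: &[f64]) -> Vec<f64> {
    changepoint_matrix(t, t_change)
        .iter()
        .zip(t)
        .map(|(row, &ti)| {
            // Rate adjustment from the changepoints active at ti.
            let rate: f64 = row.iter().zip(delta).map(|(a, d)| a * d).sum();
            // Offset adjustment that keeps the trend continuous.
            let offset: f64 = row
                .iter()
                .zip(t_change.iter().zip(delta))
                .map(|(a, (s, d))| a * (-s * d))
                .sum();
            (k + rate) * ti + (m + offset)
        })
        .collect()
}
```

For example, with `k = 1`, `m = 0`, a single changepoint at `0.5` and `delta = [1.0]`, the slope doubles at the changepoint while the trend stays continuous there.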