Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(wasmstan): use Stan's json_data var context #145

Merged
merged 3 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 11 additions & 89 deletions components/cpp/prophet-wasmstan/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// `prophet_wasmstan.wit` file. This effectively:
// - converts the inputs to the interface into something that Stan can
// understand,
// using a `stan::io::array_var_context` to hold data and initial parameters
// using a `stan::io::json::json_data` to hold data and initial parameters
// for Stan
// - uses the `model.hpp` file generated by `stanc` to create a new model with
// that data
Expand All @@ -21,6 +21,7 @@
#include <vector>

#include "model/model.hpp"
#include "stan/src/stan/io/json/json_data.hpp"
#include "structured_writer.cpp"
#include <stan/callbacks/interrupt.hpp>
#include <stan/callbacks/logger.hpp>
Expand All @@ -32,92 +33,6 @@
#include <stan/services/optimize/lbfgs.hpp>
#include <stan/services/optimize/newton.hpp>

// Convert the input data into a `stan::io::array_var_context` which can be used
// to create a Stan model.
//
// The context is built from two groups of (names, flattened values,
// dimensions): one for the `double` data and one for the `int` data. Within
// each group the values are concatenated in the same order as the names, and
// each dimensions entry gives the shape of the matching variable (an empty
// shape denotes a scalar).
//
// `data` must be non-null; its list fields are only read here, never modified
// or freed.
stan::io::array_var_context
data_context(exports_augurs_prophet_wasmstan_optimizer_data_t *data) {
  using size_vec = std::vector<size_t>;

  // ---- Real-valued (`double`) data -----------------------------------------
  std::vector<std::string> names_r = {
      "y",        // Time series.
      "t",        // Timestamps.
      "cap",      // Capacities for logistic trend.
      "t_change", // Times of trend changepoints.
      "X",        // Regressors.
      "sigmas",   // Scale on seasonality prior.
      "tau",      // Scale on changepoints prior.
  };

  std::vector<double> values_r;
  values_r.reserve(data->y.len + data->t.len + data->cap.len +
                   data->t_change.len + data->x.len + data->sigmas.len +
                   1); // Add one for tau.
  // The inputs are (pointer, length) pairs rather than containers, so append
  // each one as a pointer range instead of looping element by element.
  const auto append_r = [&values_r](const auto &list) {
    values_r.insert(values_r.end(), list.ptr, list.ptr + list.len);
  };
  append_r(data->y);
  append_r(data->t);
  append_r(data->cap);
  append_r(data->t_change);
  // NOTE(review): X is copied in the order supplied by the caller; confirm
  // that this matches the element order Stan expects for a 2-D variable.
  append_r(data->x);
  append_r(data->sigmas);
  values_r.push_back(data->tau);

  std::vector<size_vec> dims_r{
      size_vec{data->y.len},                               // y
      size_vec{data->t.len},                               // t
      size_vec{data->cap.len},                             // cap
      size_vec{data->t_change.len},                        // t_change
      size_vec{data->y.len, static_cast<size_t>(data->k)}, // X
      size_vec{data->sigmas.len},                          // sigmas
      size_vec{},                                          // tau
  };

  // ---- Integer data --------------------------------------------------------
  std::vector<std::string> names_i = {
      "T", // Number of time periods.
      "S", // Number of changepoints.
      "K", // Number of regressors.
      "trend_indicator",
      "s_a", // Indicator of additive features.
      "s_m", // Indicator of multiplicative features.
  };

  std::vector<int> values_i;
  // Four scalars plus every element of the two indicator lists.
  values_i.reserve(4 + data->s_a.len + data->s_m.len);
  // This is `T` in the STAN model definition but WIT identifiers
  // must be lower-kebab-case so we used `n` instead.
  values_i.push_back(data->n);
  values_i.push_back(data->s);
  values_i.push_back(data->k);
  values_i.push_back(data->trend_indicator);
  values_i.insert(values_i.end(), data->s_a.ptr,
                  data->s_a.ptr + data->s_a.len);
  values_i.insert(values_i.end(), data->s_m.ptr,
                  data->s_m.ptr + data->s_m.len);

  std::vector<size_vec> dims_i{
      size_vec{},                // T
      size_vec{},                // S
      size_vec{},                // K
      size_vec{},                // trend_indicator
      size_vec{data->s_a.len},   // s_a
      size_vec{data->s_m.len},   // s_m
  };

  return stan::io::array_var_context(names_r, values_r, dims_r, names_i,
                                     values_i, dims_i);
}

// Convert the input parameters into a `stan::io::array_var_context` which can
// be used to optimize a Stan model.
stan::io::array_var_context
Expand Down Expand Up @@ -344,6 +259,11 @@ bool store_optimized_params(
return true;
}

// A `std::streambuf` whose get area is an existing block of memory, so that a
// `std::istream` can read from a raw character buffer without copying it.
//
// The first argument points at the first readable character; the second
// argument is the number of readable characters (a length, despite its name).
// The buffer is not copied or owned, so it must outlive this object and any
// stream reading through it.
//
// See https://stackoverflow.com/questions/7781898/get-an-istream-from-a-char.
struct membuf : std::streambuf {
  membuf(char *begin, size_t end) {
    char *const first = begin;
    char *const past_last = first + end;
    // Get area: [first, past_last), with the read position at the start.
    setg(first, first, past_last);
  }
};

// Optimize a Prophet model using Stan.
//
// This satisfies the `optimize` interface from the `prophet_wasmstan.wit` file.
Expand All @@ -353,7 +273,7 @@ bool store_optimized_params(
// implementation will need updating too.
bool exports_augurs_prophet_wasmstan_optimizer_optimize(
exports_augurs_prophet_wasmstan_optimizer_inits_t *init,
exports_augurs_prophet_wasmstan_optimizer_data_t *data,
exports_augurs_prophet_wasmstan_optimizer_data_json_t *data,
exports_augurs_prophet_wasmstan_optimizer_optimize_opts_t *opts,
exports_augurs_prophet_wasmstan_optimizer_optimize_output_t *ret,
prophet_wasmstan_string_t *err) {
Expand All @@ -378,7 +298,9 @@ bool exports_augurs_prophet_wasmstan_optimizer_optimize(
}

// Create the data context.
stan::io::array_var_context data_ctx = data_context(data);
membuf data_str((char *)(data->ptr), data->len);
std::istream data_stream(&data_str);
stan::json::json_data data_ctx(data_stream);
sd2k marked this conversation as resolved.
Show resolved Hide resolved

// Create a model.
std::stringstream msg_stream;
Expand Down
76 changes: 33 additions & 43 deletions components/cpp/prophet-wasmstan/prophet_wasmstan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ void augurs_prophet_wasmstan_types_data_free(augurs_prophet_wasmstan_types_data_
prophet_wasmstan_list_f64_free(&ptr->sigmas);
}

// Release a `data-json` value by forwarding `ptr` to
// `prophet_wasmstan_string_free`.
// NOTE(review): this file appears to be generated bindings; prefer changing
// the WIT definition and regenerating over editing by hand — confirm.
void augurs_prophet_wasmstan_types_data_json_free(augurs_prophet_wasmstan_types_data_json_t *ptr) {
prophet_wasmstan_string_free(ptr);
}

void augurs_prophet_wasmstan_types_option_algorithm_free(augurs_prophet_wasmstan_types_option_algorithm_t *ptr) {
if (ptr->is_some) {
}
Expand Down Expand Up @@ -170,8 +174,8 @@ void exports_augurs_prophet_wasmstan_optimizer_inits_free(exports_augurs_prophet
augurs_prophet_wasmstan_types_inits_free(ptr);
}

void exports_augurs_prophet_wasmstan_optimizer_data_free(exports_augurs_prophet_wasmstan_optimizer_data_t *ptr) {
augurs_prophet_wasmstan_types_data_free(ptr);
void exports_augurs_prophet_wasmstan_optimizer_data_json_free(exports_augurs_prophet_wasmstan_optimizer_data_json_t *ptr) {
augurs_prophet_wasmstan_types_data_json_free(ptr);
}

void exports_augurs_prophet_wasmstan_optimizer_optimize_opts_free(exports_augurs_prophet_wasmstan_optimizer_optimize_opts_t *ptr) {
Expand Down Expand Up @@ -217,158 +221,158 @@ static uint8_t RET_AREA[96];
__attribute__((__export_name__("augurs:prophet-wasmstan/optimizer#optimize")))
uint8_t * __wasm_export_exports_augurs_prophet_wasmstan_optimizer_optimize(uint8_t * arg) {
augurs_prophet_wasmstan_types_option_algorithm_t option;
switch ((int32_t) *((uint8_t*) (arg + 128))) {
switch ((int32_t) *((uint8_t*) (arg + 48))) {
case 0: {
option.is_some = false;
break;
}
case 1: {
option.is_some = true;
option.val = (int32_t) *((uint8_t*) (arg + 129));
option.val = (int32_t) *((uint8_t*) (arg + 49));
break;
}
}
prophet_wasmstan_option_u32_t option0;
switch ((int32_t) *((uint8_t*) (arg + 132))) {
switch ((int32_t) *((uint8_t*) (arg + 52))) {
case 0: {
option0.is_some = false;
break;
}
case 1: {
option0.is_some = true;
option0.val = (uint32_t) (*((int32_t*) (arg + 136)));
option0.val = (uint32_t) (*((int32_t*) (arg + 56)));
break;
}
}
prophet_wasmstan_option_u32_t option1;
switch ((int32_t) *((uint8_t*) (arg + 140))) {
switch ((int32_t) *((uint8_t*) (arg + 60))) {
case 0: {
option1.is_some = false;
break;
}
case 1: {
option1.is_some = true;
option1.val = (uint32_t) (*((int32_t*) (arg + 144)));
option1.val = (uint32_t) (*((int32_t*) (arg + 64)));
break;
}
}
prophet_wasmstan_option_f64_t option2;
switch ((int32_t) *((uint8_t*) (arg + 152))) {
switch ((int32_t) *((uint8_t*) (arg + 72))) {
case 0: {
option2.is_some = false;
break;
}
case 1: {
option2.is_some = true;
option2.val = *((double*) (arg + 160));
option2.val = *((double*) (arg + 80));
break;
}
}
prophet_wasmstan_option_f64_t option3;
switch ((int32_t) *((uint8_t*) (arg + 168))) {
switch ((int32_t) *((uint8_t*) (arg + 88))) {
case 0: {
option3.is_some = false;
break;
}
case 1: {
option3.is_some = true;
option3.val = *((double*) (arg + 176));
option3.val = *((double*) (arg + 96));
break;
}
}
prophet_wasmstan_option_f64_t option4;
switch ((int32_t) *((uint8_t*) (arg + 184))) {
switch ((int32_t) *((uint8_t*) (arg + 104))) {
case 0: {
option4.is_some = false;
break;
}
case 1: {
option4.is_some = true;
option4.val = *((double*) (arg + 192));
option4.val = *((double*) (arg + 112));
break;
}
}
prophet_wasmstan_option_f64_t option5;
switch ((int32_t) *((uint8_t*) (arg + 200))) {
switch ((int32_t) *((uint8_t*) (arg + 120))) {
case 0: {
option5.is_some = false;
break;
}
case 1: {
option5.is_some = true;
option5.val = *((double*) (arg + 208));
option5.val = *((double*) (arg + 128));
break;
}
}
prophet_wasmstan_option_f64_t option6;
switch ((int32_t) *((uint8_t*) (arg + 216))) {
switch ((int32_t) *((uint8_t*) (arg + 136))) {
case 0: {
option6.is_some = false;
break;
}
case 1: {
option6.is_some = true;
option6.val = *((double*) (arg + 224));
option6.val = *((double*) (arg + 144));
break;
}
}
prophet_wasmstan_option_f64_t option7;
switch ((int32_t) *((uint8_t*) (arg + 232))) {
switch ((int32_t) *((uint8_t*) (arg + 152))) {
case 0: {
option7.is_some = false;
break;
}
case 1: {
option7.is_some = true;
option7.val = *((double*) (arg + 240));
option7.val = *((double*) (arg + 160));
break;
}
}
prophet_wasmstan_option_u32_t option8;
switch ((int32_t) *((uint8_t*) (arg + 248))) {
switch ((int32_t) *((uint8_t*) (arg + 168))) {
case 0: {
option8.is_some = false;
break;
}
case 1: {
option8.is_some = true;
option8.val = (uint32_t) (*((int32_t*) (arg + 252)));
option8.val = (uint32_t) (*((int32_t*) (arg + 172)));
break;
}
}
prophet_wasmstan_option_u32_t option9;
switch ((int32_t) *((uint8_t*) (arg + 256))) {
switch ((int32_t) *((uint8_t*) (arg + 176))) {
case 0: {
option9.is_some = false;
break;
}
case 1: {
option9.is_some = true;
option9.val = (uint32_t) (*((int32_t*) (arg + 260)));
option9.val = (uint32_t) (*((int32_t*) (arg + 180)));
break;
}
}
prophet_wasmstan_option_bool_t option10;
switch ((int32_t) *((uint8_t*) (arg + 264))) {
switch ((int32_t) *((uint8_t*) (arg + 184))) {
case 0: {
option10.is_some = false;
break;
}
case 1: {
option10.is_some = true;
option10.val = (int32_t) *((uint8_t*) (arg + 265));
option10.val = (int32_t) *((uint8_t*) (arg + 185));
break;
}
}
prophet_wasmstan_option_u32_t option11;
switch ((int32_t) *((uint8_t*) (arg + 268))) {
switch ((int32_t) *((uint8_t*) (arg + 188))) {
case 0: {
option11.is_some = false;
break;
}
case 1: {
option11.is_some = true;
option11.val = (uint32_t) (*((int32_t*) (arg + 272)));
option11.val = (uint32_t) (*((int32_t*) (arg + 192)));
sd2k marked this conversation as resolved.
Show resolved Hide resolved
break;
}
}
Expand All @@ -379,21 +383,7 @@ uint8_t * __wasm_export_exports_augurs_prophet_wasmstan_optimizer_optimize(uint8
(prophet_wasmstan_list_f64_t) (prophet_wasmstan_list_f64_t) { (double*)(*((uint8_t **) (arg + 24))), (*((size_t*) (arg + 28))) },
(double) *((double*) (arg + 32)),
};
exports_augurs_prophet_wasmstan_optimizer_data_t arg13 = (augurs_prophet_wasmstan_types_data_t) {
(int32_t) *((int32_t*) (arg + 40)),
(prophet_wasmstan_list_f64_t) (prophet_wasmstan_list_f64_t) { (double*)(*((uint8_t **) (arg + 44))), (*((size_t*) (arg + 48))) },
(prophet_wasmstan_list_f64_t) (prophet_wasmstan_list_f64_t) { (double*)(*((uint8_t **) (arg + 52))), (*((size_t*) (arg + 56))) },
(prophet_wasmstan_list_f64_t) (prophet_wasmstan_list_f64_t) { (double*)(*((uint8_t **) (arg + 60))), (*((size_t*) (arg + 64))) },
(int32_t) *((int32_t*) (arg + 68)),
(prophet_wasmstan_list_f64_t) (prophet_wasmstan_list_f64_t) { (double*)(*((uint8_t **) (arg + 72))), (*((size_t*) (arg + 76))) },
(augurs_prophet_wasmstan_types_trend_indicator_t) (int32_t) *((uint8_t*) (arg + 80)),
(int32_t) *((int32_t*) (arg + 84)),
(prophet_wasmstan_list_s32_t) (prophet_wasmstan_list_s32_t) { (int32_t*)(*((uint8_t **) (arg + 88))), (*((size_t*) (arg + 92))) },
(prophet_wasmstan_list_s32_t) (prophet_wasmstan_list_s32_t) { (int32_t*)(*((uint8_t **) (arg + 96))), (*((size_t*) (arg + 100))) },
(prophet_wasmstan_list_f64_t) (prophet_wasmstan_list_f64_t) { (double*)(*((uint8_t **) (arg + 104))), (*((size_t*) (arg + 108))) },
(prophet_wasmstan_list_f64_t) (prophet_wasmstan_list_f64_t) { (double*)(*((uint8_t **) (arg + 112))), (*((size_t*) (arg + 116))) },
(double) *((double*) (arg + 120)),
};
exports_augurs_prophet_wasmstan_optimizer_data_json_t arg13 = (prophet_wasmstan_string_t) { (uint8_t*)(*((uint8_t **) (arg + 40))), (*((size_t*) (arg + 44))) };
exports_augurs_prophet_wasmstan_optimizer_optimize_opts_t arg14 = (augurs_prophet_wasmstan_types_optimize_opts_t) {
(augurs_prophet_wasmstan_types_option_algorithm_t) option,
(prophet_wasmstan_option_u32_t) option0,
Expand Down
Loading
Loading