Skip to content

Commit

Permalink
[bind] bind specified in #52 (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaccBarker authored Nov 6, 2024
1 parent c50b927 commit 087110f
Showing 1 changed file with 87 additions and 38 deletions.
125 changes: 87 additions & 38 deletions src/virus.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#include "virus.hpp"
#include "virus-distribute-meat.hpp"

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

using namespace epiworld;
using namespace epiworldpy;
using namespace pybind11::literals;
namespace py = pybind11;

static epiworld::Virus<int> new_virus(std::string name, double prevalence,
Expand All @@ -23,40 +28,30 @@ static epiworld::Virus<int> new_virus(std::string name, double prevalence,
return virus;
}

static std::string get_name(epiworld::Virus<int> &self) {
return self.get_name();
}
static py::dict get_queue(epiworld::Virus<int> virus) {
epiworld_fast_int init;
epiworld_fast_int end;
epiworld_fast_int removed;

static void set_name(epiworld::Virus<int> &self, std::string name) {
self.set_name(name);
}
virus.get_queue(&init, &end, &removed);

static void set_prob_infecting(epiworld::Virus<int> virus,
double prob_infecting) {
virus.set_prob_infecting(prob_infecting);
}
/* Return to Python. */
py::dict ret("init"_a = init, "end"_a = end, "removed"_a = removed);

static void set_prob_recovery(epiworld::Virus<int> virus,
double prob_recovery) {
virus.set_prob_recovery(prob_recovery);
return ret;
}

static void set_prob_death(epiworld::Virus<int> virus, double prob_death) {
virus.set_prob_death(prob_death);
}
static py::dict get_state(epiworld::Virus<int> virus) {
epiworld_fast_int init;
epiworld_fast_int end;
epiworld_fast_int removed;

static void set_incubation(epiworld::Virus<int> virus, double incubation) {
virus.set_prob_infecting(incubation);
}
virus.get_state(&init, &end, &removed);

static void set_state(epiworld::Virus<int> virus, size_t init, size_t end,
size_t removed) {
virus.set_state(init, end, removed);
}
/* Return to Python. */
py::dict ret("init"_a = init, "end"_a = end, "removed"_a = removed);

static void set_distribution(epiworld::Virus<int> virus,
VirusToAgentFun<int> fun) {
virus.set_distribution(fun);
return ret;
}

static void print_virus(epiworld::Virus<int> virus) { virus.print(); }
Expand All @@ -66,20 +61,65 @@ void epiworldpy::export_virus(pybind11::class_<epiworld::Virus<int>> &c) {
py::arg("prevalence"), py::arg("as_proportion"),
py::arg("prob_infecting"), py::arg("prob_recovery"),
py::arg("prob_death"), py::arg("post_immunity"), py::arg("incubation"))
.def("get_name", &get_name, "Get the tool name.")
.def("set_name", &set_name, "Set the tool name.", py::arg("name"))
.def("set_state", &set_state, "Set some state.", py::arg("init"),
py::arg("end"), py::arg("removed"))
.def("set_prob_infecting", &set_prob_infecting,
.def("get_name", &epiworld::Virus<int>::get_name, "Get the tool name.")
.def("set_name", &epiworld::Virus<int>::set_name, "Set the tool name.",
py::arg("name"))
.def("set_state", &epiworld::Virus<int>::set_state, "Set some state.",
py::arg("init"), py::arg("end"), py::arg("removed"))
.def("set_prob_infecting",
pybind11::detail::overload_cast_impl<epiworld_double>()(
&epiworld::Virus<int>::set_prob_infecting),
"Set the probability of infection.", py::arg("prob_infecting"))
.def("set_prob_recovery", &set_prob_recovery,
.def("set_prob_recovery",
pybind11::detail::overload_cast_impl<epiworld_double>()(
&epiworld::Virus<int>::set_prob_recovery),
"Set the probability for recovery.", py::arg("prob_recovery"))
.def("set_prob_death", &set_prob_death,
.def("set_prob_death",
pybind11::detail::overload_cast_impl<epiworld_double>()(
&epiworld::Virus<int>::set_prob_death),
"Set the probability for mortality.", py::arg("prob_death"))
.def("set_incubation", &set_incubation, "Set the incubation period.",
py::arg("incubation"))
.def("set_distribution", &set_distribution,
.def("set_incubation",
pybind11::detail::overload_cast_impl<epiworld_double>()(
&epiworld::Virus<int>::set_incubation),
"Set the incubation period.", py::arg("incubation"))
.def("set_post_recovery", &epiworld::Virus<int>::set_post_recovery, "",
py::arg("Set the post recovery."))
.def("set_post_immunity",
pybind11::detail::overload_cast_impl<epiworld_double>()(
&epiworld::Virus<int>::set_post_immunity),
"Set the post immunity.", py::arg(""))
.def("set_distribution_fun", &epiworld::Virus<int>::set_distribution,
"Set the distribution function.", py::arg("fun"))
.def("set_prob_infecting_fun",
&epiworld::Virus<int>::set_prob_infecting_fun,
"Set the probability of infection callback.", py::arg("fun"))
.def("set_prob_recovery_fun",
&epiworld::Virus<int>::set_prob_recovery_fun,
"Set the probability of recovery callback.", py::arg("fun"))
.def("set_prob_death_fun", &epiworld::Virus<int>::set_prob_death_fun,
"Set the probability of death callback.", py::arg("fun"))
.def("set_incubation_fun", &epiworld::Virus<int>::set_incubation_fun,
"Set the incubation callback.", py::arg("fun"))
.def("set_queue", &epiworld::Virus<int>::set_queue, py::arg("init"),
py::arg("end"), py::arg("removed"))
.def("set_date", &epiworld::Virus<int>::set_date, py::arg("date"))
.def("set_sequence", &epiworld::Virus<int>::set_sequence,
"Set the sequence value.", py::arg("sequence"))
.def("get_queue", &get_queue, "Get the queue.")
.def("get_state", &get_state, "Set some state.")
.def("get_incubation", &epiworld::Virus<int>::get_incubation,
"Get the incubation value.", py::arg("model"))
.def("get_prob_infecting", &epiworld::Virus<int>::get_prob_infecting,
"Get the probability of infection.", py::arg("model"))
.def("get_prob_recovery", &epiworld::Virus<int>::get_prob_recovery,
"Get the probability of recovery.", py::arg("model"))
.def("get_prob_death", &epiworld::Virus<int>::get_prob_death,
"Get the probability of death.", py::arg("model"))
.def("get_id", &epiworld::Virus<int>::get_id, "Get the ID of this virus.")
.def("get_date", &epiworld::Virus<int>::get_date, "Get the date.") // ?
.def("distribute", &epiworld::Virus<int>::distribute, py::arg("model"))
.def("post_recovery", &epiworld::Virus<int>::post_recovery,
py::arg("model"))
.def("print", &print_virus, "Print information about this virus.");
}

Expand All @@ -88,6 +128,12 @@ static VirusToAgentFun<int> new_distribution_fun(
return VirusToAgentFun<int>(fun);
}

static VirusFun<int> new_virus_fun(
std::function<epiworld_double(Agent<int> *, Virus<int> &, Model<int> *)>
&fun) {
return VirusFun<int>(fun);
}

static VirusToAgentFun<int> new_random_distribution(double prevalence,
bool as_proportion) {
return distribute_virus_randomly(prevalence);
Expand All @@ -100,9 +146,12 @@ static VirusToAgentFun<int> new_distribute_to_set(std::vector<size_t> ids) {

void epiworldpy::export_virus_to_agent_fun(
pybind11::class_<epiworld::VirusToAgentFun<int>> &c) {
c.def_static("new_distribution_fun", &new_distribution_fun,
"Create a new distribution function based off a lambda.",
c.def_static("new_virus_fun", &new_virus_fun,
"Create a new, generic, virus callback function.",
py::arg("fun"))
.def_static("new_distribution_fun", &new_distribution_fun,
"Create a new distribution function based off a lambda.",
py::arg("fun"))
.def_static("new_random_distribution", &new_random_distribution,
"Randomly infect agents in the model.")
.def_static("new_distribute_to_set", &new_distribute_to_set,
Expand Down

0 comments on commit 087110f

Please sign in to comment.