Skip to content

Commit

Permalink
Working example!
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Sep 14, 2023
1 parent a3e9845 commit 03e364e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 3 deletions.
54 changes: 53 additions & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h> // for py::array_t
#include "barry.hpp"
#include "models/defm.hpp"
#include <vector>

namespace py = pybind11;

#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
Expand All @@ -8,7 +13,37 @@ int add(int i, int j) {
return i + j;
}

namespace py = pybind11;

std::shared_ptr< defm::DEFM > new_defm(
py::array_t< int > id,
py::array_t< int > y,
py::array_t< double > x
) {

// std::vector<int> id = {1, 2, 3, 4, 5};
// std::vector<int> y = {1, 2, 3, 4, 5};
// std::vector<double> x = {1.0, 2.0, 3.0, 4.0, 5.0};

int n_id = id.size();
int n_y = y.size();
int n_x = x.size();

std::shared_ptr< defm::DEFM > object(new defm::DEFM(
id.mutable_data(0u),
y.mutable_data(0u),
x.mutable_data(0u),
static_cast< size_t >(n_id),
static_cast< size_t >(n_y),
static_cast< size_t >(n_x),
false,
false
));

return object;
}




PYBIND11_MODULE(_core, m) {
m.doc() = R"pbdoc(
Expand All @@ -30,12 +65,29 @@ PYBIND11_MODULE(_core, m) {
Some other explanation about the add function.
)pbdoc");

// Example with lambda function
m.def("subtract", [](int i, int j) { return i - j; }, R"pbdoc(
Subtract two numbers
Some other explanation about the subtract function.
)pbdoc");

// Only this is necesary to expose the class
py::class_<defm::DEFM, std::shared_ptr<defm::DEFM>>(m, "DEFM")
// .def(py::init<>())
.def("print", &defm::DEFM::print, R"pbdoc(
Print the object
Some other explanation about the print function.)
)pbdoc");

// Example with shared_ptr
m.def("new_defm", &new_defm, R"pbdoc(
Create a new DEFM object
Some other explanation about the new_defm function.
)pbdoc");

#ifdef VERSION_INFO
m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
#else
Expand Down
4 changes: 2 additions & 2 deletions src/scikit_build_example/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations

from ._core import __doc__, __version__, add, subtract
from ._core import __doc__, __version__, add, subtract, new_defm

__all__ = ["__doc__", "__version__", "add", "subtract"]
__all__ = ["__doc__", "__version__", "add", "subtract", "new_defm"]
7 changes: 7 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from __future__ import annotations

import scikit_build_example as m
import numpy as np

y = np.array([1, 2, 3])
x = np.array([1, 2, 3])
id = np.array([1, 2, 3])

m.new_defm(y, x, id)


def test_version():
Expand Down

0 comments on commit 03e364e

Please sign in to comment.