From 03e364ee1ed613cdfd4e2edb44782d338cdb4a07 Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Thu, 14 Sep 2023 14:29:45 -0600 Subject: [PATCH] Working example! --- src/main.cpp | 54 +++++++++++++++++++++++++++- src/scikit_build_example/__init__.py | 4 +-- tests/test_basic.py | 7 ++++ 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/src/main.cpp b/src/main.cpp index 309e34d..edd0eb2 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,5 +1,10 @@ #include +#include // for py::array_t #include "barry.hpp" +#include "models/defm.hpp" +#include + +namespace py = pybind11; #define STRINGIFY(x) #x #define MACRO_STRINGIFY(x) STRINGIFY(x) @@ -8,7 +13,37 @@ int add(int i, int j) { return i + j; } -namespace py = pybind11; + +std::shared_ptr< defm::DEFM > new_defm( + py::array_t< int > id, + py::array_t< int > y, + py::array_t< double > x + ) { + + // std::vector id = {1, 2, 3, 4, 5}; + // std::vector y = {1, 2, 3, 4, 5}; + // std::vector x = {1.0, 2.0, 3.0, 4.0, 5.0}; + + int n_id = id.size(); + int n_y = y.size(); + int n_x = x.size(); + + std::shared_ptr< defm::DEFM > object(new defm::DEFM( + id.mutable_data(0u), + y.mutable_data(0u), + x.mutable_data(0u), + static_cast< size_t >(n_id), + static_cast< size_t >(n_y), + static_cast< size_t >(n_x), + false, + false + )); + + return object; +} + + + PYBIND11_MODULE(_core, m) { m.doc() = R"pbdoc( @@ -30,12 +65,29 @@ PYBIND11_MODULE(_core, m) { Some other explanation about the add function. )pbdoc"); + // Example with lambda function m.def("subtract", [](int i, int j) { return i - j; }, R"pbdoc( Subtract two numbers Some other explanation about the subtract function. )pbdoc"); + // Only this is necesary to expose the class + py::class_>(m, "DEFM") + // .def(py::init<>()) + .def("print", &defm::DEFM::print, R"pbdoc( + Print the object + + Some other explanation about the print function.) + )pbdoc"); + + // Example with shared_ptr + m.def("new_defm", &new_defm, R"pbdoc( + Create a new DEFM object + + Some other explanation about the new_defm function. + )pbdoc"); + #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); #else diff --git a/src/scikit_build_example/__init__.py b/src/scikit_build_example/__init__.py index 15188e5..a81b80b 100644 --- a/src/scikit_build_example/__init__.py +++ b/src/scikit_build_example/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from ._core import __doc__, __version__, add, subtract +from ._core import __doc__, __version__, add, subtract, new_defm -__all__ = ["__doc__", "__version__", "add", "subtract"] +__all__ = ["__doc__", "__version__", "add", "subtract", "new_defm"] diff --git a/tests/test_basic.py b/tests/test_basic.py index 93da71b..7ce9a1a 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,6 +1,13 @@ from __future__ import annotations import scikit_build_example as m +import numpy as np + +y = np.array([1, 2, 3]) +x = np.array([1, 2, 3]) +id = np.array([1, 2, 3]) + +m.new_defm(y, x, id) def test_version():