Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setup the flaxlib in C++, using Meson and Nanobind. #4380

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/flax_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,10 @@ jobs:
uses: astral-sh/setup-uv@v2
with:
version: "0.3.0"
- name: Setup Rust (flaxlib)
uses: actions-rust-lang/setup-rust-toolchain@v1

- name: Install dependencies
run: |
uv sync --extra all --extra testing --extra docs
uv pip install ./flaxlib
- name: Install JAX
run: |
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ build/
docs*/**/_autosummary
docs*/_build
docs*/**/tmp
flaxlib_src/build
flaxlib_src/builddir
flaxlib_src/dist
flaxlib_src/subprojects

# used by direnv
.envrc
Expand Down
1 change: 0 additions & 1 deletion flaxlib/README.md

This file was deleted.

15 changes: 0 additions & 15 deletions flaxlib/flaxlib/__init__.py

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
34 changes: 34 additions & 0 deletions flaxlib_src/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# flaxlib

## Build flaxlib from source

Install necessary dependencies to build the C++ based package.

```shell
pip install meson-python ninja build
```

Clone the Flax repository, navigate to the flaxlib source directory.

```shell
git clone [email protected]:google/flax.git
cd flax/flaxlib_src
```

Configure the build.

```shell
mkdir -p subprojects
meson wrap install robin-map
meson wrap install nanobind
meson setup builddir
```

Compile the code. You'll need to run this repeatedly if you modify the source
code. Note that the actual wheel name will differ depending on your system.

```shell
meson compile -C builddir
python -m build . -w
pip install dist/flaxlib-0.0.1-cp311-cp311-macosx_14_0_arm64.whl --force-reinstall
```
File renamed without changes.
14 changes: 14 additions & 0 deletions flaxlib_src/meson.build
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
project(
'flaxlib',
'cpp',
version: '0.0.1',
default_options: ['cpp_std=c++17'],
)
py = import('python').find_installation()
nanobind_dep = dependency('nanobind', static: true)
py.extension_module(
'flaxlib',
sources: ['src/lib.cc'],
dependencies: [nanobind_dep],
install: true,
)
8 changes: 3 additions & 5 deletions flaxlib/pyproject.toml → flaxlib_src/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
[build-system]
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"
requires = ['meson-python']
build-backend = 'mesonpy'

[project]
name = "flaxlib"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: C++",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
Expand All @@ -15,5 +15,3 @@ dynamic = ["version"]
tests = [
"pytest",
]
[tool.maturin]
features = ["pyo3/extension-module"]
14 changes: 14 additions & 0 deletions flaxlib_src/src/lib.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include <string>

#include "nanobind/nanobind.h"
#include "nanobind/stl/string.h"

namespace flaxlib {
std::string sum_as_string(int a, int b) {
return std::to_string(a + b);
}

NB_MODULE(flaxlib, m) {
m.def("sum_as_string", &sum_as_string);
}
} // namespace flaxlib
File renamed without changes.
File renamed without changes.
13 changes: 8 additions & 5 deletions tests/flaxlib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
import flaxlib

# TODO: Re-enable this test after setting up CI build for flaxlib CC.

class TestFlaxlib(absltest.TestCase):
# from absl.testing import absltest
# import flaxlib

def test_flaxlib(self):
self.assertEqual(flaxlib.sum_as_string(1, 2), '3')

# class TestFlaxlib(absltest.TestCase):

# def test_flaxlib(self):
# self.assertEqual(flaxlib.sum_as_string(1, 2), '3')
Loading