From 4aaac1fcbdd4b11f1004c6124cb8c7d893c2d857 Mon Sep 17 00:00:00 2001 From: Stephen Nicholas Swatman Date: Mon, 26 Aug 2024 14:58:14 +0200 Subject: [PATCH] feat: Add B-field accessors to Python bindings As #3479 reveals, we don't currently have any clean, cache-aware ways of accessing B-fields in Python code. In order to avoid hacks, this commit adds the necessary bindings to allow us to cleanly access B-fields with cache objects. --- Examples/Python/src/Base.cpp | 7 ++++ Examples/Python/src/MagneticField.cpp | 20 +++++++++- Examples/Python/src/ModuleEntry.cpp | 2 + Examples/Python/tests/test_magnetic_field.py | 39 +++++++++++++++++++- 4 files changed, 66 insertions(+), 2 deletions(-) diff --git a/Examples/Python/src/Base.cpp b/Examples/Python/src/Base.cpp index e720857c4f0..b6582c15d1f 100644 --- a/Examples/Python/src/Base.cpp +++ b/Examples/Python/src/Base.cpp @@ -12,6 +12,7 @@ #include "Acts/Geometry/GeometryContext.hpp" #include "Acts/MagneticField/MagneticFieldContext.hpp" #include "Acts/Plugins/Python/Utilities.hpp" +#include "Acts/Utilities/Any.hpp" #include "Acts/Utilities/AxisFwd.hpp" #include "Acts/Utilities/BinningData.hpp" #include "Acts/Utilities/CalibrationContext.hpp" @@ -42,6 +43,12 @@ void addContext(Context& ctx) { .def(py::init<>()); } +void addAny(Context& ctx) { + auto& m = ctx.get("main"); + + py::class_>(m, "AnyBase512").def(py::init<>()); +} + void addUnits(Context& ctx) { auto& m = ctx.get("main"); auto u = m.def_submodule("UnitConstants"); diff --git a/Examples/Python/src/MagneticField.cpp b/Examples/Python/src/MagneticField.cpp index a8279675696..bff62ec0e80 100644 --- a/Examples/Python/src/MagneticField.cpp +++ b/Examples/Python/src/MagneticField.cpp @@ -36,12 +36,30 @@ using namespace pybind11::literals; namespace Acts::Python { +/// @brief Get the value of a field, throwing an exception if the result is +/// invalid. +Acts::Vector3 getField(Acts::MagneticFieldProvider& self, + const Acts::Vector3& position, + Acts::MagneticFieldProvider::Cache& cache) { + if (Result res = self.getField(position, cache); !res.ok()) { + std::stringstream ss; + + ss << "Field lookup failure with error: \"" << res.error() << "\""; + + throw std::runtime_error{ss.str()}; + } else { + return *res; + } +} + void addMagneticField(Context& ctx) { auto [m, mex, prop] = ctx.get("main", "examples", "propagation"); py::class_>( - m, "MagneticFieldProvider"); + m, "MagneticFieldProvider") + .def("getField", &getField) + .def("makeCache", &Acts::MagneticFieldProvider::makeCache); py::class_>( diff --git a/Examples/Python/src/ModuleEntry.cpp b/Examples/Python/src/ModuleEntry.cpp index 6ea2a964a9a..53e5d0dd3af 100644 --- a/Examples/Python/src/ModuleEntry.cpp +++ b/Examples/Python/src/ModuleEntry.cpp @@ -40,6 +40,7 @@ using namespace Acts::Python; namespace Acts::Python { void addContext(Context& ctx); +void addAny(Context& ctx); void addUnits(Context& ctx); void addFramework(Context& ctx); void addLogging(Context& ctx); @@ -108,6 +109,7 @@ PYBIND11_MODULE(ActsPythonBindings, m) { } addContext(ctx); + addAny(ctx); addUnits(ctx); addFramework(ctx); addLogging(ctx); diff --git a/Examples/Python/tests/test_magnetic_field.py b/Examples/Python/tests/test_magnetic_field.py index 3c1aae5c537..a51ebd55ab6 100644 --- a/Examples/Python/tests/test_magnetic_field.py +++ b/Examples/Python/tests/test_magnetic_field.py @@ -1,4 +1,5 @@ import pytest +import random import acts import acts.examples @@ -7,16 +8,52 @@ def test_null_bfield(): - assert acts.NullBField() + nb = acts.NullBField() + assert nb + + ct = acts.MagneticFieldContext() + assert ct + + fc = nb.makeCache(ct) + assert fc + + for i in range(100): + x = random.uniform(-10000.0, 10000.0) + y = random.uniform(-10000.0, 10000.0) + z = random.uniform(-10000.0, 10000.0) + + rv = nb.getField(acts.Vector3(x, y, z), fc) + + assert rv[0] == pytest.approx(0.0) + assert rv[1] == pytest.approx(0.0) + assert rv[2] == pytest.approx(0.0) def test_constant_bfield(): with pytest.raises(TypeError): acts.ConstantBField() + v = acts.Vector3(1, 2, 3) cb = acts.ConstantBField(v) assert cb + ct = acts.MagneticFieldContext() + assert ct + + fc = cb.makeCache(ct) + assert fc + + for i in range(100): + x = random.uniform(-10000.0, 10000.0) + y = random.uniform(-10000.0, 10000.0) + z = random.uniform(-10000.0, 10000.0) + + rv = cb.getField(acts.Vector3(x, y, z), fc) + + assert rv[0] == pytest.approx(1.0) + assert rv[1] == pytest.approx(2.0) + assert rv[2] == pytest.approx(3.0) + def test_solenoid(conf_const): solenoid = conf_const(