From 0abb551685f56aaf328277bf312b1cf177022ac2 Mon Sep 17 00:00:00 2001 From: "Garth N. Wells" Date: Tue, 9 Jan 2024 17:20:27 +0000 Subject: [PATCH 1/5] Make access to C++ form classes from Python simpler --- python/demo/demo_static-condensation.py | 23 +++--------- python/dolfinx/fem/forms.py | 35 +++++++++++++------ .../test/unit/fem/test_custom_jit_kernels.py | 22 +++++------- 3 files changed, 38 insertions(+), 42 deletions(-) diff --git a/python/demo/demo_static-condensation.py b/python/demo/demo_static-condensation.py index 57219aaad22..bffe1ba02d2 100644 --- a/python/demo/demo_static-condensation.py +++ b/python/demo/demo_static-condensation.py @@ -32,9 +32,7 @@ import ufl from basix.ufl import element -from dolfinx import default_real_type, geometry -from dolfinx.cpp.fem import (Form_complex64, Form_complex128, Form_float32, - Form_float64) +from dolfinx import geometry from dolfinx.fem import (Form, Function, IntegralType, dirichletbc, form, functionspace, locate_dofs_topological) from dolfinx.fem.petsc import (apply_lifting, assemble_matrix, assemble_vector, @@ -42,9 +40,10 @@ from dolfinx.io import XDMFFile from dolfinx.jit import ffcx_jit from dolfinx.mesh import locate_entities_boundary, meshtags -from ffcx.codegeneration.utils import numba_ufcx_kernel_signature as ufcx_signature +from ffcx.codegeneration.utils import \ + numba_ufcx_kernel_signature as ufcx_signature -if default_real_type == np.float32: +if PETSc.RealType == np.float32: print("float32 not yet supported for this demo.") exit(0) @@ -139,25 +138,13 @@ def tabulate_A(A_, w_, c_, coords_, entity_local_index, permutation=ffi.NULL): # Prepare a Form with a condensed tabulation kernel -formtype = None -if PETSc.ScalarType == np.float32: # type: ignore - formtype = Form_float32 -elif PETSc.ScalarType == np.float64: # type: ignore - formtype = Form_float64 -elif PETSc.ScalarType == np.complex64: # type: ignore - formtype = Form_complex64 -elif PETSc.ScalarType == np.complex128: # type: ignore - formtype = Form_complex128 -else: - raise RuntimeError(f"Unsupported PETSc ScalarType '{PETSc.ScalarType}'.") # type: ignore - +formtype = Form.cpp_class(PETSc.ScalarType) # type: ignore cells = np.arange(msh.topology.index_map(msh.topology.dim).size_local) integrals = {IntegralType.cell: [(-1, tabulate_A.address, cells)]} a_cond = Form(formtype([U._cpp_object, U._cpp_object], integrals, [], [], False, None)) A_cond = assemble_matrix(a_cond, bcs=[bc]) A_cond.assemble() - b = assemble_vector(b1) apply_lifting(b, [a_cond], bcs=[[bc]]) b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE) # type: ignore diff --git a/python/dolfinx/fem/forms.py b/python/dolfinx/fem/forms.py index d78c8f01ca2..e535e6c1d2a 100644 --- a/python/dolfinx/fem/forms.py +++ b/python/dolfinx/fem/forms.py @@ -81,6 +81,30 @@ def integral_types(self): """Integral types in the form""" return self._cpp_object.integral_types + @staticmethod + def cpp_class(dtype: npt.DTypeLike) -> typing.Union[_cpp.fem.Form_complex64, + _cpp.fem.Form_complex128, + _cpp.fem.Form_float32, + _cpp.fem.Form_float64]: + """Return the wrapped C++ class of a variational form of a specific scalar type. + + Args: + dtype: Scalar type of the required form class. + + Returns: + Wrapped C++ form class of the requested type. + """ + if dtype == np.float32: + return _cpp.fem.Form_float32 + elif dtype == np.float64: + return _cpp.fem.Form_float64 + elif dtype == np.complex64: + return _cpp.fem.Form_complex64 + elif dtype == np.complex128: + return _cpp.fem.Form_complex128 + else: + raise NotImplementedError(f"Type {dtype} not supported.") + _ufl_to_dolfinx_domain = {"cell": IntegralType.cell, "exterior_facet": IntegralType.exterior_facet, @@ -115,16 +139,7 @@ def form(form: typing.Union[ufl.Form, typing.Iterable[ufl.Form]], form_compiler_options = dict() form_compiler_options["scalar_type"] = dtype - if dtype == np.float32: - ftype = _cpp.fem.Form_float32 - elif dtype == np.float64: - ftype = _cpp.fem.Form_float64 - elif dtype == np.complex64: - ftype = _cpp.fem.Form_complex64 - elif dtype == np.complex128: - ftype = _cpp.fem.Form_complex128 - else: - raise NotImplementedError(f"Type {dtype} not supported.") + ftype = Form.cpp_class(dtype) def _form(form): """Compile a single UFL form""" diff --git a/python/test/unit/fem/test_custom_jit_kernels.py b/python/test/unit/fem/test_custom_jit_kernels.py index 27f6e01bd6f..0cad7dfea04 100644 --- a/python/test/unit/fem/test_custom_jit_kernels.py +++ b/python/test/unit/fem/test_custom_jit_kernels.py @@ -78,13 +78,9 @@ def tabulate(b_, w_, c_, coords_, local_index, orientation): return tabulate -@pytest.mark.parametrize("dtype,formtype", [ - (np.float32, _cpp.fem.Form_float32), - (np.float64, _cpp.fem.Form_float64), - (np.complex64, _cpp.fem.Form_complex64), - (np.complex128, _cpp.fem.Form_complex128) -]) -def test_numba_assembly(dtype, formtype): +@pytest.mark.parametrize("dtype", [np.float32, np.float64, + np.complex64, np.complex128]) +def test_numba_assembly(dtype): xdtype = np.real(dtype(0)).dtype k2 = tabulate_rank2(dtype, xdtype) k1 = tabulate_rank1(dtype, xdtype) @@ -94,6 +90,7 @@ def test_numba_assembly(dtype, formtype): integrals = {IntegralType.cell: [(-1, k2.address, cells), (12, k2.address, np.arange(0)), (2, k2.address, np.arange(0))]} + formtype = Form.cpp_class(dtype) a = Form(formtype([V._cpp_object, V._cpp_object], integrals, [], [], False, None)) integrals = {IntegralType.cell: [(-1, k1.address, cells)]} L = Form(formtype([V._cpp_object], integrals, [], [], False, None)) @@ -111,13 +108,9 @@ def test_numba_assembly(dtype, formtype): list_timings(MPI.COMM_WORLD, [TimingType.wall]) -@pytest.mark.parametrize("dtype,formtype", [ - (np.float32, _cpp.fem.Form_float32), - (np.float64, _cpp.fem.Form_float64), - (np.complex64, _cpp.fem.Form_complex64), - (np.complex128, _cpp.fem.Form_complex128) -]) -def test_coefficient(dtype, formtype): +@pytest.mark.parametrize("dtype", [np.float32, np.float64, + np.complex64, np.complex128]) +def test_coefficient(dtype): xdtype = np.real(dtype(0)).dtype k1 = tabulate_rank1_coeff(dtype, xdtype) @@ -130,6 +123,7 @@ def test_coefficient(dtype, formtype): tdim = mesh.topology.dim num_cells = mesh.topology.index_map(tdim).size_local + mesh.topology.index_map(tdim).num_ghosts integrals = {IntegralType.cell: [(1, k1.address, np.arange(num_cells, dtype=np.int32))]} + formtype = Form.cpp_class(dtype) L = Form(formtype([V._cpp_object], integrals, [vals._cpp_object], [], False, None)) b = dolfinx.fem.assemble_vector(L) From fa0ba59dd503801c976ef3bde15a65e1d6d4a966 Mon Sep 17 00:00:00 2001 From: "Garth N. Wells" Date: Tue, 9 Jan 2024 17:29:41 +0000 Subject: [PATCH 2/5] Type ignore --- python/demo/demo_static-condensation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/demo/demo_static-condensation.py b/python/demo/demo_static-condensation.py index bffe1ba02d2..ddc4edb2741 100644 --- a/python/demo/demo_static-condensation.py +++ b/python/demo/demo_static-condensation.py @@ -43,7 +43,7 @@ from ffcx.codegeneration.utils import \ numba_ufcx_kernel_signature as ufcx_signature -if PETSc.RealType == np.float32: +if PETSc.RealType == np.float32: # type: ignore print("float32 not yet supported for this demo.") exit(0) From 036a775daecc2202e7af273122926d3ec114835e Mon Sep 17 00:00:00 2001 From: "Garth N. Wells" Date: Thu, 18 Jan 2024 19:24:22 +0000 Subject: [PATCH 3/5] Make static function a free function --- python/demo/demo_static-condensation.py | 4 +- python/dolfinx/__init__.py | 3 +- python/dolfinx/fem/__init__.py | 2 +- python/dolfinx/fem/forms.py | 48 ++++++++++--------- python/dolfinx/utils.py | 4 +- .../test/unit/fem/test_custom_jit_kernels.py | 6 +-- 6 files changed, 37 insertions(+), 30 deletions(-) diff --git a/python/demo/demo_static-condensation.py b/python/demo/demo_static-condensation.py index ddc4edb2741..2a662097143 100644 --- a/python/demo/demo_static-condensation.py +++ b/python/demo/demo_static-condensation.py @@ -34,7 +34,7 @@ from basix.ufl import element from dolfinx import geometry from dolfinx.fem import (Form, Function, IntegralType, dirichletbc, form, - functionspace, locate_dofs_topological) + functionspace, locate_dofs_topological, form_cpp_class) from dolfinx.fem.petsc import (apply_lifting, assemble_matrix, assemble_vector, set_bc) from dolfinx.io import XDMFFile @@ -138,7 +138,7 @@ def tabulate_A(A_, w_, c_, coords_, entity_local_index, permutation=ffi.NULL): # Prepare a Form with a condensed tabulation kernel -formtype = Form.cpp_class(PETSc.ScalarType) # type: ignore +formtype = form_cpp_class(PETSc.ScalarType) # type: ignore cells = np.arange(msh.topology.index_map(msh.topology.dim).size_local) integrals = {IntegralType.cell: [(-1, tabulate_A.address, cells)]} a_cond = Form(formtype([U._cpp_object, U._cpp_object], integrals, [], [], False, None)) diff --git a/python/dolfinx/__init__.py b/python/dolfinx/__init__.py index b91d7691f4d..8e789acb561 100644 --- a/python/dolfinx/__init__.py +++ b/python/dolfinx/__init__.py @@ -20,7 +20,8 @@ from dolfinx import common from dolfinx import cpp as _cpp -from dolfinx import fem, geometry, graph, io, jit, la, log, mesh, nls, plot, utils +from dolfinx import (fem, geometry, graph, io, jit, la, log, mesh, nls, plot, + utils) # Initialise logging from dolfinx.common import (TimingType, git_commit_hash, has_debug, has_kahip, has_parmetis, list_timings, timing) diff --git a/python/dolfinx/fem/__init__.py b/python/dolfinx/fem/__init__.py index 5f2a4e7d4b2..613def40e97 100644 --- a/python/dolfinx/fem/__init__.py +++ b/python/dolfinx/fem/__init__.py @@ -46,4 +46,4 @@ def create_sparsity_pattern(a: Form): "form", "IntegralType", "create_vector", "locate_dofs_geometrical", "locate_dofs_topological", "extract_function_spaces", "transpose_dofmap", "create_nonmatching_meshes_interpolation_data", - "CoordinateElement", "coordinate_element"] + "CoordinateElement", "coordinate_element", "form_cpp_class"] diff --git a/python/dolfinx/fem/forms.py b/python/dolfinx/fem/forms.py index e535e6c1d2a..877672b24ec 100644 --- a/python/dolfinx/fem/forms.py +++ b/python/dolfinx/fem/forms.py @@ -81,29 +81,33 @@ def integral_types(self): """Integral types in the form""" return self._cpp_object.integral_types - @staticmethod - def cpp_class(dtype: npt.DTypeLike) -> typing.Union[_cpp.fem.Form_complex64, - _cpp.fem.Form_complex128, - _cpp.fem.Form_float32, - _cpp.fem.Form_float64]: - """Return the wrapped C++ class of a variational form of a specific scalar type. - Args: - dtype: Scalar type of the required form class. +def form_cpp_class(dtype: npt.DTypeLike) -> typing.Union[_cpp.fem.Form_float32, + _cpp.fem.Form_float64, + _cpp.fem.Form_complex64, + _cpp.fem.Form_complex128]: + """Return the wrapped C++ class of a variational form of a specific scalar type. - Returns: - Wrapped C++ form class of the requested type. - """ - if dtype == np.float32: - return _cpp.fem.Form_float32 - elif dtype == np.float64: - return _cpp.fem.Form_float64 - elif dtype == np.complex64: - return _cpp.fem.Form_complex64 - elif dtype == np.complex128: - return _cpp.fem.Form_complex128 - else: - raise NotImplementedError(f"Type {dtype} not supported.") + Args: + dtype: Scalar type of the required form class. + + Returns: + Wrapped C++ form class of the requested type. + + Note: + This function is for advanced usage, typically when writing + custom kernels using Numba or C. + """ + if dtype == np.float32: + return _cpp.fem.Form_float32 + elif dtype == np.float64: + return _cpp.fem.Form_float64 + elif dtype == np.complex64: + return _cpp.fem.Form_complex64 + elif dtype == np.complex128: + return _cpp.fem.Form_complex128 + else: + raise NotImplementedError(f"Type {dtype} not supported.") _ufl_to_dolfinx_domain = {"cell": IntegralType.cell, @@ -139,7 +143,7 @@ def form(form: typing.Union[ufl.Form, typing.Iterable[ufl.Form]], form_compiler_options = dict() form_compiler_options["scalar_type"] = dtype - ftype = Form.cpp_class(dtype) + ftype = form_cpp_class(dtype) def _form(form): """Compile a single UFL form""" diff --git a/python/dolfinx/utils.py b/python/dolfinx/utils.py index e5a6f91d674..af16805fb53 100644 --- a/python/dolfinx/utils.py +++ b/python/dolfinx/utils.py @@ -69,6 +69,7 @@ def set_vals(A: int, """ try: import petsc4py.PETSc as _PETSc + import llvmlite as _llvmlite import numba as _numba _llvmlite.binding.load_library_permanently(str(get_petsc_lib())) @@ -157,10 +158,11 @@ def set_vals(A: int, ffi.from_buffer(rows(data), mode) """ try: + from petsc4py import PETSc as _PETSc + import cffi as _cffi import numba as _numba import numba.core.typing.cffi_utils as _cffi_support - from petsc4py import PETSc as _PETSc # Register complex types _ffi = _cffi.FFI() diff --git a/python/test/unit/fem/test_custom_jit_kernels.py b/python/test/unit/fem/test_custom_jit_kernels.py index 0cad7dfea04..fbf977be733 100644 --- a/python/test/unit/fem/test_custom_jit_kernels.py +++ b/python/test/unit/fem/test_custom_jit_kernels.py @@ -19,7 +19,7 @@ from dolfinx import TimingType from dolfinx import cpp as _cpp from dolfinx import fem, la, list_timings -from dolfinx.fem import Form, Function, IntegralType, functionspace +from dolfinx.fem import Form, Function, IntegralType, functionspace, form_cpp_class from dolfinx.mesh import create_unit_square import ffcx.codegeneration.utils @@ -90,7 +90,7 @@ def test_numba_assembly(dtype): integrals = {IntegralType.cell: [(-1, k2.address, cells), (12, k2.address, np.arange(0)), (2, k2.address, np.arange(0))]} - formtype = Form.cpp_class(dtype) + formtype = form_cpp_class(dtype) a = Form(formtype([V._cpp_object, V._cpp_object], integrals, [], [], False, None)) integrals = {IntegralType.cell: [(-1, k1.address, cells)]} L = Form(formtype([V._cpp_object], integrals, [], [], False, None)) @@ -123,7 +123,7 @@ def test_coefficient(dtype): tdim = mesh.topology.dim num_cells = mesh.topology.index_map(tdim).size_local + mesh.topology.index_map(tdim).num_ghosts integrals = {IntegralType.cell: [(1, k1.address, np.arange(num_cells, dtype=np.int32))]} - formtype = Form.cpp_class(dtype) + formtype = form_cpp_class(dtype) L = Form(formtype([V._cpp_object], integrals, [vals._cpp_object], [], False, None)) b = dolfinx.fem.assemble_vector(L) From e17cda09ae6fa36ed140342d5e6d442fbd341ce5 Mon Sep 17 00:00:00 2001 From: "Garth N. Wells" Date: Thu, 18 Jan 2024 19:26:24 +0000 Subject: [PATCH 4/5] Tidy up --- python/test/unit/fem/test_custom_jit_kernels.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/test/unit/fem/test_custom_jit_kernels.py b/python/test/unit/fem/test_custom_jit_kernels.py index fbf977be733..a2a5bbd5972 100644 --- a/python/test/unit/fem/test_custom_jit_kernels.py +++ b/python/test/unit/fem/test_custom_jit_kernels.py @@ -78,8 +78,7 @@ def tabulate(b_, w_, c_, coords_, local_index, orientation): return tabulate -@pytest.mark.parametrize("dtype", [np.float32, np.float64, - np.complex64, np.complex128]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128]) def test_numba_assembly(dtype): xdtype = np.real(dtype(0)).dtype k2 = tabulate_rank2(dtype, xdtype) @@ -108,8 +107,7 @@ def test_numba_assembly(dtype): list_timings(MPI.COMM_WORLD, [TimingType.wall]) -@pytest.mark.parametrize("dtype", [np.float32, np.float64, - np.complex64, np.complex128]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128]) def test_coefficient(dtype): xdtype = np.real(dtype(0)).dtype k1 = tabulate_rank1_coeff(dtype, xdtype) From 480a3672268a065316945d1b89c468d915e3e945 Mon Sep 17 00:00:00 2001 From: "Garth N. Wells" Date: Thu, 18 Jan 2024 22:25:01 +0000 Subject: [PATCH 5/5] Add import --- python/dolfinx/fem/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/dolfinx/fem/__init__.py b/python/dolfinx/fem/__init__.py index 613def40e97..6abf5071f2c 100644 --- a/python/dolfinx/fem/__init__.py +++ b/python/dolfinx/fem/__init__.py @@ -16,7 +16,8 @@ locate_dofs_geometrical, locate_dofs_topological) from dolfinx.fem.dofmap import DofMap from dolfinx.fem.element import CoordinateElement, coordinate_element -from dolfinx.fem.forms import Form, extract_function_spaces, form +from dolfinx.fem.forms import (Form, extract_function_spaces, form, + form_cpp_class) from dolfinx.fem.function import (Constant, ElementMetaData, Expression, Function, FunctionSpace, functionspace)