diff --git a/python/demo/demo_static-condensation.py b/python/demo/demo_static-condensation.py index 57219aaad22..2a662097143 100644 --- a/python/demo/demo_static-condensation.py +++ b/python/demo/demo_static-condensation.py @@ -32,19 +32,18 @@ import ufl from basix.ufl import element -from dolfinx import default_real_type, geometry -from dolfinx.cpp.fem import (Form_complex64, Form_complex128, Form_float32, - Form_float64) +from dolfinx import geometry from dolfinx.fem import (Form, Function, IntegralType, dirichletbc, form, - functionspace, locate_dofs_topological) + functionspace, locate_dofs_topological, form_cpp_class) from dolfinx.fem.petsc import (apply_lifting, assemble_matrix, assemble_vector, set_bc) from dolfinx.io import XDMFFile from dolfinx.jit import ffcx_jit from dolfinx.mesh import locate_entities_boundary, meshtags -from ffcx.codegeneration.utils import numba_ufcx_kernel_signature as ufcx_signature +from ffcx.codegeneration.utils import \ + numba_ufcx_kernel_signature as ufcx_signature -if default_real_type == np.float32: +if PETSc.RealType == np.float32: # type: ignore print("float32 not yet supported for this demo.") exit(0) @@ -139,25 +138,13 @@ def tabulate_A(A_, w_, c_, coords_, entity_local_index, permutation=ffi.NULL): # Prepare a Form with a condensed tabulation kernel -formtype = None -if PETSc.ScalarType == np.float32: # type: ignore - formtype = Form_float32 -elif PETSc.ScalarType == np.float64: # type: ignore - formtype = Form_float64 -elif PETSc.ScalarType == np.complex64: # type: ignore - formtype = Form_complex64 -elif PETSc.ScalarType == np.complex128: # type: ignore - formtype = Form_complex128 -else: - raise RuntimeError(f"Unsupported PETSc ScalarType '{PETSc.ScalarType}'.") # type: ignore - +formtype = form_cpp_class(PETSc.ScalarType) # type: ignore cells = np.arange(msh.topology.index_map(msh.topology.dim).size_local) integrals = {IntegralType.cell: [(-1, tabulate_A.address, cells)]} a_cond = Form(formtype([U._cpp_object, U._cpp_object], integrals, [], [], False, None)) A_cond = assemble_matrix(a_cond, bcs=[bc]) A_cond.assemble() - b = assemble_vector(b1) apply_lifting(b, [a_cond], bcs=[[bc]]) b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE) # type: ignore diff --git a/python/dolfinx/__init__.py b/python/dolfinx/__init__.py index b91d7691f4d..8e789acb561 100644 --- a/python/dolfinx/__init__.py +++ b/python/dolfinx/__init__.py @@ -20,7 +20,8 @@ from dolfinx import common from dolfinx import cpp as _cpp -from dolfinx import fem, geometry, graph, io, jit, la, log, mesh, nls, plot, utils +from dolfinx import (fem, geometry, graph, io, jit, la, log, mesh, nls, plot, + utils) # Initialise logging from dolfinx.common import (TimingType, git_commit_hash, has_debug, has_kahip, has_parmetis, list_timings, timing) diff --git a/python/dolfinx/fem/__init__.py b/python/dolfinx/fem/__init__.py index 5f2a4e7d4b2..6abf5071f2c 100644 --- a/python/dolfinx/fem/__init__.py +++ b/python/dolfinx/fem/__init__.py @@ -16,7 +16,8 @@ locate_dofs_geometrical, locate_dofs_topological) from dolfinx.fem.dofmap import DofMap from dolfinx.fem.element import CoordinateElement, coordinate_element -from dolfinx.fem.forms import Form, extract_function_spaces, form +from dolfinx.fem.forms import (Form, extract_function_spaces, form, + form_cpp_class) from dolfinx.fem.function import (Constant, ElementMetaData, Expression, Function, FunctionSpace, functionspace) @@ -46,4 +47,4 @@ def create_sparsity_pattern(a: Form): "form", "IntegralType", "create_vector", "locate_dofs_geometrical", "locate_dofs_topological", "extract_function_spaces", "transpose_dofmap", "create_nonmatching_meshes_interpolation_data", - "CoordinateElement", "coordinate_element"] + "CoordinateElement", "coordinate_element", "form_cpp_class"] diff --git a/python/dolfinx/fem/forms.py b/python/dolfinx/fem/forms.py index d78c8f01ca2..877672b24ec 100644 --- a/python/dolfinx/fem/forms.py +++ b/python/dolfinx/fem/forms.py @@ -82,6 +82,34 @@ def integral_types(self): return self._cpp_object.integral_types +def form_cpp_class(dtype: npt.DTypeLike) -> typing.Union[_cpp.fem.Form_float32, + _cpp.fem.Form_float64, + _cpp.fem.Form_complex64, + _cpp.fem.Form_complex128]: + """Return the wrapped C++ class of a variational form of a specific scalar type. + + Args: + dtype: Scalar type of the required form class. + + Returns: + Wrapped C++ form class of the requested type. + + Note: + This function is for advanced usage, typically when writing + custom kernels using Numba or C. + """ + if dtype == np.float32: + return _cpp.fem.Form_float32 + elif dtype == np.float64: + return _cpp.fem.Form_float64 + elif dtype == np.complex64: + return _cpp.fem.Form_complex64 + elif dtype == np.complex128: + return _cpp.fem.Form_complex128 + else: + raise NotImplementedError(f"Type {dtype} not supported.") + + _ufl_to_dolfinx_domain = {"cell": IntegralType.cell, "exterior_facet": IntegralType.exterior_facet, "interior_facet": IntegralType.interior_facet, @@ -115,16 +143,7 @@ def form(form: typing.Union[ufl.Form, typing.Iterable[ufl.Form]], form_compiler_options = dict() form_compiler_options["scalar_type"] = dtype - if dtype == np.float32: - ftype = _cpp.fem.Form_float32 - elif dtype == np.float64: - ftype = _cpp.fem.Form_float64 - elif dtype == np.complex64: - ftype = _cpp.fem.Form_complex64 - elif dtype == np.complex128: - ftype = _cpp.fem.Form_complex128 - else: - raise NotImplementedError(f"Type {dtype} not supported.") + ftype = form_cpp_class(dtype) def _form(form): """Compile a single UFL form""" diff --git a/python/dolfinx/utils.py b/python/dolfinx/utils.py index e5a6f91d674..af16805fb53 100644 --- a/python/dolfinx/utils.py +++ b/python/dolfinx/utils.py @@ -69,6 +69,7 @@ def set_vals(A: int, """ try: import petsc4py.PETSc as _PETSc + import llvmlite as _llvmlite import numba as _numba _llvmlite.binding.load_library_permanently(str(get_petsc_lib())) @@ -157,10 +158,11 @@ def set_vals(A: int, ffi.from_buffer(rows(data), mode) """ try: + from petsc4py import PETSc as _PETSc + import cffi as _cffi import numba as _numba import numba.core.typing.cffi_utils as _cffi_support - from petsc4py import PETSc as _PETSc # Register complex types _ffi = _cffi.FFI() diff --git a/python/test/unit/fem/test_custom_jit_kernels.py b/python/test/unit/fem/test_custom_jit_kernels.py index 27f6e01bd6f..a2a5bbd5972 100644 --- a/python/test/unit/fem/test_custom_jit_kernels.py +++ b/python/test/unit/fem/test_custom_jit_kernels.py @@ -19,7 +19,7 @@ from dolfinx import TimingType from dolfinx import cpp as _cpp from dolfinx import fem, la, list_timings -from dolfinx.fem import Form, Function, IntegralType, functionspace +from dolfinx.fem import Form, Function, IntegralType, functionspace, form_cpp_class from dolfinx.mesh import create_unit_square import ffcx.codegeneration.utils @@ -78,13 +78,8 @@ def tabulate(b_, w_, c_, coords_, local_index, orientation): return tabulate -@pytest.mark.parametrize("dtype,formtype", [ - (np.float32, _cpp.fem.Form_float32), - (np.float64, _cpp.fem.Form_float64), - (np.complex64, _cpp.fem.Form_complex64), - (np.complex128, _cpp.fem.Form_complex128) -]) -def test_numba_assembly(dtype, formtype): +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128]) +def test_numba_assembly(dtype): xdtype = np.real(dtype(0)).dtype k2 = tabulate_rank2(dtype, xdtype) k1 = tabulate_rank1(dtype, xdtype) @@ -94,6 +89,7 @@ def test_numba_assembly(dtype, formtype): integrals = {IntegralType.cell: [(-1, k2.address, cells), (12, k2.address, np.arange(0)), (2, k2.address, np.arange(0))]} + formtype = form_cpp_class(dtype) a = Form(formtype([V._cpp_object, V._cpp_object], integrals, [], [], False, None)) integrals = {IntegralType.cell: [(-1, k1.address, cells)]} L = Form(formtype([V._cpp_object], integrals, [], [], False, None)) @@ -111,13 +107,8 @@ def test_numba_assembly(dtype, formtype): list_timings(MPI.COMM_WORLD, [TimingType.wall]) -@pytest.mark.parametrize("dtype,formtype", [ - (np.float32, _cpp.fem.Form_float32), - (np.float64, _cpp.fem.Form_float64), - (np.complex64, _cpp.fem.Form_complex64), - (np.complex128, _cpp.fem.Form_complex128) -]) -def test_coefficient(dtype, formtype): +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128]) +def test_coefficient(dtype): xdtype = np.real(dtype(0)).dtype k1 = tabulate_rank1_coeff(dtype, xdtype) @@ -130,6 +121,7 @@ def test_coefficient(dtype, formtype): tdim = mesh.topology.dim num_cells = mesh.topology.index_map(tdim).size_local + mesh.topology.index_map(tdim).num_ghosts integrals = {IntegralType.cell: [(1, k1.address, np.arange(num_cells, dtype=np.int32))]} + formtype = form_cpp_class(dtype) L = Form(formtype([V._cpp_object], integrals, [vals._cpp_object], [], False, None)) b = dolfinx.fem.assemble_vector(L)