Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify access to wrapped C++ form classes #2977

Merged
merged 6 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions python/demo/demo_static-condensation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,18 @@

import ufl
from basix.ufl import element
from dolfinx import default_real_type, geometry
from dolfinx.cpp.fem import (Form_complex64, Form_complex128, Form_float32,
Form_float64)
from dolfinx import geometry
from dolfinx.fem import (Form, Function, IntegralType, dirichletbc, form,
functionspace, locate_dofs_topological)
functionspace, locate_dofs_topological, form_cpp_class)
from dolfinx.fem.petsc import (apply_lifting, assemble_matrix, assemble_vector,
set_bc)
from dolfinx.io import XDMFFile
from dolfinx.jit import ffcx_jit
from dolfinx.mesh import locate_entities_boundary, meshtags
from ffcx.codegeneration.utils import numba_ufcx_kernel_signature as ufcx_signature
from ffcx.codegeneration.utils import \
numba_ufcx_kernel_signature as ufcx_signature

if default_real_type == np.float32:
if PETSc.RealType == np.float32: # type: ignore
print("float32 not yet supported for this demo.")
exit(0)

Expand Down Expand Up @@ -139,25 +138,13 @@ def tabulate_A(A_, w_, c_, coords_, entity_local_index, permutation=ffi.NULL):


# Prepare a Form with a condensed tabulation kernel
formtype = None
if PETSc.ScalarType == np.float32: # type: ignore
formtype = Form_float32
elif PETSc.ScalarType == np.float64: # type: ignore
formtype = Form_float64
elif PETSc.ScalarType == np.complex64: # type: ignore
formtype = Form_complex64
elif PETSc.ScalarType == np.complex128: # type: ignore
formtype = Form_complex128
else:
raise RuntimeError(f"Unsupported PETSc ScalarType '{PETSc.ScalarType}'.") # type: ignore

formtype = form_cpp_class(PETSc.ScalarType) # type: ignore
cells = np.arange(msh.topology.index_map(msh.topology.dim).size_local)
integrals = {IntegralType.cell: [(-1, tabulate_A.address, cells)]}
a_cond = Form(formtype([U._cpp_object, U._cpp_object], integrals, [], [], False, None))

A_cond = assemble_matrix(a_cond, bcs=[bc])
A_cond.assemble()

b = assemble_vector(b1)
apply_lifting(b, [a_cond], bcs=[[bc]])
b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE) # type: ignore
Expand Down
3 changes: 2 additions & 1 deletion python/dolfinx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

from dolfinx import common
from dolfinx import cpp as _cpp
from dolfinx import fem, geometry, graph, io, jit, la, log, mesh, nls, plot, utils
from dolfinx import (fem, geometry, graph, io, jit, la, log, mesh, nls, plot,
utils)
# Initialise logging
from dolfinx.common import (TimingType, git_commit_hash, has_debug, has_kahip,
has_parmetis, list_timings, timing)
Expand Down
5 changes: 3 additions & 2 deletions python/dolfinx/fem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
locate_dofs_geometrical, locate_dofs_topological)
from dolfinx.fem.dofmap import DofMap
from dolfinx.fem.element import CoordinateElement, coordinate_element
from dolfinx.fem.forms import Form, extract_function_spaces, form
from dolfinx.fem.forms import (Form, extract_function_spaces, form,
form_cpp_class)
from dolfinx.fem.function import (Constant, ElementMetaData, Expression,
Function, FunctionSpace, functionspace)

Expand Down Expand Up @@ -46,4 +47,4 @@ def create_sparsity_pattern(a: Form):
"form", "IntegralType", "create_vector",
"locate_dofs_geometrical", "locate_dofs_topological",
"extract_function_spaces", "transpose_dofmap", "create_nonmatching_meshes_interpolation_data",
"CoordinateElement", "coordinate_element"]
"CoordinateElement", "coordinate_element", "form_cpp_class"]
39 changes: 29 additions & 10 deletions python/dolfinx/fem/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,34 @@ def integral_types(self):
return self._cpp_object.integral_types


def form_cpp_class(dtype: npt.DTypeLike) -> typing.Union[_cpp.fem.Form_float32,
_cpp.fem.Form_float64,
_cpp.fem.Form_complex64,
_cpp.fem.Form_complex128]:
"""Return the wrapped C++ class of a variational form of a specific scalar type.

Args:
dtype: Scalar type of the required form class.

Returns:
Wrapped C++ form class of the requested type.

Note:
This function is for advanced usage, typically when writing
custom kernels using Numba or C.
"""
if dtype == np.float32:
return _cpp.fem.Form_float32
elif dtype == np.float64:
return _cpp.fem.Form_float64
elif dtype == np.complex64:
return _cpp.fem.Form_complex64
elif dtype == np.complex128:
return _cpp.fem.Form_complex128
else:
raise NotImplementedError(f"Type {dtype} not supported.")


_ufl_to_dolfinx_domain = {"cell": IntegralType.cell,
"exterior_facet": IntegralType.exterior_facet,
"interior_facet": IntegralType.interior_facet,
Expand Down Expand Up @@ -115,16 +143,7 @@ def form(form: typing.Union[ufl.Form, typing.Iterable[ufl.Form]],
form_compiler_options = dict()

form_compiler_options["scalar_type"] = dtype
if dtype == np.float32:
ftype = _cpp.fem.Form_float32
elif dtype == np.float64:
ftype = _cpp.fem.Form_float64
elif dtype == np.complex64:
ftype = _cpp.fem.Form_complex64
elif dtype == np.complex128:
ftype = _cpp.fem.Form_complex128
else:
raise NotImplementedError(f"Type {dtype} not supported.")
ftype = form_cpp_class(dtype)

def _form(form):
"""Compile a single UFL form"""
Expand Down
4 changes: 3 additions & 1 deletion python/dolfinx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def set_vals(A: int,
"""
try:
import petsc4py.PETSc as _PETSc

import llvmlite as _llvmlite
import numba as _numba
_llvmlite.binding.load_library_permanently(str(get_petsc_lib()))
Expand Down Expand Up @@ -157,10 +158,11 @@ def set_vals(A: int,
ffi.from_buffer(rows(data), mode)
"""
try:
from petsc4py import PETSc as _PETSc

import cffi as _cffi
import numba as _numba
import numba.core.typing.cffi_utils as _cffi_support
from petsc4py import PETSc as _PETSc

# Register complex types
_ffi = _cffi.FFI()
Expand Down
22 changes: 7 additions & 15 deletions python/test/unit/fem/test_custom_jit_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from dolfinx import TimingType
from dolfinx import cpp as _cpp
from dolfinx import fem, la, list_timings
from dolfinx.fem import Form, Function, IntegralType, functionspace
from dolfinx.fem import Form, Function, IntegralType, functionspace, form_cpp_class
from dolfinx.mesh import create_unit_square
import ffcx.codegeneration.utils

Expand Down Expand Up @@ -78,13 +78,8 @@ def tabulate(b_, w_, c_, coords_, local_index, orientation):
return tabulate


@pytest.mark.parametrize("dtype,formtype", [
(np.float32, _cpp.fem.Form_float32),
(np.float64, _cpp.fem.Form_float64),
(np.complex64, _cpp.fem.Form_complex64),
(np.complex128, _cpp.fem.Form_complex128)
])
def test_numba_assembly(dtype, formtype):
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128])
def test_numba_assembly(dtype):
xdtype = np.real(dtype(0)).dtype
k2 = tabulate_rank2(dtype, xdtype)
k1 = tabulate_rank1(dtype, xdtype)
Expand All @@ -94,6 +89,7 @@ def test_numba_assembly(dtype, formtype):
integrals = {IntegralType.cell: [(-1, k2.address, cells),
(12, k2.address, np.arange(0)),
(2, k2.address, np.arange(0))]}
formtype = form_cpp_class(dtype)
a = Form(formtype([V._cpp_object, V._cpp_object], integrals, [], [], False, None))
integrals = {IntegralType.cell: [(-1, k1.address, cells)]}
L = Form(formtype([V._cpp_object], integrals, [], [], False, None))
Expand All @@ -111,13 +107,8 @@ def test_numba_assembly(dtype, formtype):
list_timings(MPI.COMM_WORLD, [TimingType.wall])


@pytest.mark.parametrize("dtype,formtype", [
(np.float32, _cpp.fem.Form_float32),
(np.float64, _cpp.fem.Form_float64),
(np.complex64, _cpp.fem.Form_complex64),
(np.complex128, _cpp.fem.Form_complex128)
])
def test_coefficient(dtype, formtype):
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128])
def test_coefficient(dtype):
xdtype = np.real(dtype(0)).dtype
k1 = tabulate_rank1_coeff(dtype, xdtype)

Expand All @@ -130,6 +121,7 @@ def test_coefficient(dtype, formtype):
tdim = mesh.topology.dim
num_cells = mesh.topology.index_map(tdim).size_local + mesh.topology.index_map(tdim).num_ghosts
integrals = {IntegralType.cell: [(1, k1.address, np.arange(num_cells, dtype=np.int32))]}
formtype = form_cpp_class(dtype)
L = Form(formtype([V._cpp_object], integrals, [vals._cpp_object], [], False, None))

b = dolfinx.fem.assemble_vector(L)
Expand Down
Loading