Skip to content

Commit

Permalink
Add flag to disable computation of doflocs (to save memory) (#1107)
Browse files Browse the repository at this point in the history
  • Loading branch information
kinnala authored Mar 19, 2024
1 parent 65b4a17 commit c6d3463
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 24 deletions.
26 changes: 14 additions & 12 deletions skfem/assembly/basis/abstract_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __init__(self,
intorder: Optional[int] = None,
quadrature: Optional[Tuple[ndarray, ndarray]] = None,
refdom: Type[Refdom] = Refdom,
dofs: Optional[Dofs] = None):
dofs: Optional[Dofs] = None,
disable_doflocs: bool = False):

if mesh.refdom != elem.refdom:
raise ValueError("Incompatible Mesh and Element.")
Expand All @@ -56,17 +57,18 @@ def __init__(self,
self.dofs = Dofs(mesh, elem) if dofs is None else dofs

# global degree-of-freedom location
try:
doflocs = self.mapping.F(elem.doflocs.T)
self.doflocs = np.zeros((doflocs.shape[0], self.N))

# match mapped dofs and global dof numbering
for itr in range(doflocs.shape[0]):
for jtr in range(self.dofs.element_dofs.shape[0]):
self.doflocs[itr, self.dofs.element_dofs[jtr]] =\
doflocs[itr, :, jtr]
except Exception:
logger.warning("Unable to calculate global DOF locations.")
if not disable_doflocs:
try:
doflocs = self.mapping.F(elem.doflocs.T)
self.doflocs = np.zeros((doflocs.shape[0], self.N))

# match mapped dofs and global dof numbering
for itr in range(doflocs.shape[0]):
for jtr in range(self.dofs.element_dofs.shape[0]):
self.doflocs[itr, self.dofs.element_dofs[jtr]] =\
doflocs[itr, :, jtr]
except Exception:
logger.warning("Unable to calculate global DOF locations.")

self.mesh = mesh
self.elem = elem
Expand Down
8 changes: 7 additions & 1 deletion skfem/assembly/basis/cell_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def __init__(self,
intorder: Optional[int] = None,
elements: Optional[Any] = None,
quadrature: Optional[Tuple[ndarray, ndarray]] = None,
dofs: Optional[Dofs] = None):
dofs: Optional[Dofs] = None,
disable_doflocs: bool = False):
"""Combine :class:`~skfem.mesh.Mesh` and
:class:`~skfem.element.Element` into a set of precomputed global basis
functions.
Expand All @@ -70,6 +71,10 @@ def __init__(self,
Optional tuple of quadrature points and weights.
dofs
Optional :class:`~skfem.assembly.Dofs` object.
disable_doflocs
If `True`, the computation of global DOF locations is
disabled. This may save memory on large meshes if DOF
locations are not required.
"""
logger.info("Initializing {}({}, {})".format(type(self).__name__,
Expand All @@ -83,6 +88,7 @@ def __init__(self,
quadrature,
mesh.refdom,
dofs,
disable_doflocs,
)

if elements is None:
Expand Down
8 changes: 7 additions & 1 deletion skfem/assembly/basis/facet_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def __init__(self,
quadrature: Optional[Tuple[ndarray, ndarray]] = None,
facets: Optional[Any] = None,
dofs: Optional[Dofs] = None,
side: int = 0):
side: int = 0,
disable_doflocs: bool = False):
"""Precomputed global basis on boundary facets.
Parameters
Expand All @@ -51,6 +52,10 @@ def __init__(self,
Optional subset of facet indices.
dofs
Optional :class:`~skfem.assembly.Dofs` object.
disable_doflocs
If `True`, the computation of global DOF locations is
disabled. This may save memory on large meshes if DOF
locations are not required.
"""
typestr = ("{}({}, {})".format(type(self).__name__,
Expand All @@ -65,6 +70,7 @@ def __init__(self,
quadrature,
mesh.brefdom,
dofs,
disable_doflocs,
)

# by default use boundary facets
Expand Down
4 changes: 3 additions & 1 deletion skfem/assembly/basis/interior_facet_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def __init__(self,
quadrature: Optional[Tuple[ndarray, ndarray]] = None,
facets: Optional[Any] = None,
dofs: Optional[Dofs] = None,
side: int = 0):
side: int = 0,
disable_doflocs: bool = False):
"""Precomputed global basis on interior facets."""

if facets is None:
Expand All @@ -42,4 +43,5 @@ def __init__(self,
facets=facets,
dofs=dofs,
side=side,
disable_doflocs=disable_doflocs,
)
1 change: 0 additions & 1 deletion skfem/assembly/form/coo_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def tolocal(self, basis=None):
if self.local_shape is None:
raise NotImplementedError("Cannot build local matrices if "
"local_shape is not specified.")
assert len(self.local_shape) == 2

local = np.moveaxis(self.data.reshape(self.local_shape + (-1,),
order='C'), -1, 0)
Expand Down
17 changes: 10 additions & 7 deletions tests/test_autodiff.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import pytest
import numpy as np
import jax.numpy as jnp
from numpy.testing import (assert_array_almost_equal,
assert_almost_equal)
from skfem.experimental.autodiff import NonlinearForm
from skfem.experimental.autodiff.helpers import (grad, dot,
ddot, mul,
div, sym_grad,
transpose,
eye, trace)
try:
import jax.numpy as jnp
from skfem.experimental.autodiff import NonlinearForm
from skfem.experimental.autodiff.helpers import (grad, dot,
ddot, mul,
div, sym_grad,
transpose,
eye, trace)
except:
pass
from skfem.assembly import Basis
from skfem.mesh import MeshTri, MeshQuad
from skfem.element import (ElementTriP1, ElementTriP2,
Expand Down
20 changes: 19 additions & 1 deletion tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from skfem import BilinearForm, LinearForm, asm, solve, condense, projection
from skfem.mesh import (Mesh, MeshTri, MeshTet, MeshHex,
MeshQuad, MeshLine1, MeshWedge1)
from skfem.assembly import CellBasis, FacetBasis, Dofs, Functional
from skfem.assembly import (CellBasis, FacetBasis, Dofs, Functional,
InteriorFacetBasis)
from skfem.mapping import MappingIsoparametric
from skfem.element import (ElementVectorH1, ElementTriP2, ElementTriP1,
ElementTetP2, ElementHexS2, ElementHex2,
Expand Down Expand Up @@ -646,3 +647,20 @@ def test_with_elements():
assert basis.mapping == basis_half.mapping
assert basis.quadrature == basis_half.quadrature
assert all(basis_half.tind == basis.mesh.normalize_elements('a'))


def test_disable_doflocs():
mesh = MeshTri().refined(3)
basis = CellBasis(mesh, ElementTriP1())
basisd = CellBasis(mesh, ElementTriP1(), disable_doflocs=True)
fbasis = FacetBasis(mesh, ElementTriP1())
fbasisd = FacetBasis(mesh, ElementTriP1(), disable_doflocs=True)
ifbasis = InteriorFacetBasis(mesh, ElementTriP1())
ifbasisd = InteriorFacetBasis(mesh, ElementTriP1(),
disable_doflocs=True)
assert not hasattr(fbasisd, 'doflocs')
assert hasattr(fbasis, 'doflocs')
assert not hasattr(basisd, 'doflocs')
assert hasattr(basis, 'doflocs')
assert not hasattr(ifbasisd, 'doflocs')
assert hasattr(ifbasis, 'doflocs')

0 comments on commit c6d3463

Please sign in to comment.