Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally assemble using multiple threads #625

Merged
merged 3 commits into from
Apr 19, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions skfem/assembly/form/bilinear_form.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Optional, Tuple
from threading import Thread
from itertools import product

import numpy as np
from numpy import ndarray
Expand Down Expand Up @@ -73,7 +75,10 @@ def _assemble(self,

# initialize COO data structures
sz = ubasis.Nbfun * vbasis.Nbfun * nt
data = np.zeros(sz, dtype=self.dtype)
if self.nthreads > 0:
data = np.zeros((vbasis.Nbfun, ubasis.Nbfun, nt), dtype=self.dtype)
else:
data = np.zeros(sz, dtype=self.dtype)
rows = np.zeros(sz, dtype=np.int64)
cols = np.zeros(sz, dtype=np.int64)

Expand All @@ -84,12 +89,36 @@ def _assemble(self,
nt * (vbasis.Nbfun * j + i + 1))
rows[ixs] = vbasis.element_dofs[i]
cols[ixs] = ubasis.element_dofs[j]
data[ixs] = self._kernel(
ubasis.basis[j],
vbasis.basis[i],
wdict,
dx,
)
if self.nthreads <= 0:
data[ixs] = self._kernel(
ubasis.basis[j],
vbasis.basis[i],
wdict,
dx,
)

if self.nthreads > 0:
# create indices for linear loop over local stiffness matrix
indices = np.array(
[[i, j] for j, i in product(range(ubasis.Nbfun),
range(vbasis.Nbfun))]
)

# split local stiffness matrix elements to threads
threads = [
Thread(
target=self._threaded_kernel,
args=(data, ij, ubasis.basis, vbasis.basis, wdict, dx)
) for ij in np.array_split(indices, self.nthreads, axis=0)
]

# start threads and wait for finishing
for t in threads:
t.start()
for t in threads:
t.join()

data = np.transpose(data, (1, 0, 2)).flatten('C')

return data, rows, cols, (vbasis.N, ubasis.N)

Expand All @@ -114,3 +143,13 @@ def assemble(self, *args, **kwargs) -> csr_matrix:

def _kernel(self, u, v, w, dx):
    """Evaluate the form for one (u, v) pair and reduce along axis 1.

    ``self.form`` is called with the unpacked entries of ``u``, then the
    unpacked entries of ``v``, then ``w``.  The result is multiplied by
    ``dx`` and summed over axis 1.

    NOTE(review): ``dx`` presumably holds per-element quadrature weights,
    so the sum would be a per-element integral -- confirm against the caller.
    """
    integrand = self.form(*u, *v, w)
    weighted = integrand * dx
    return np.sum(weighted, axis=1)

def _threaded_kernel(self, data, ix, ubasis, vbasis, wdict, dx):
    """Fill the entries of ``data`` selected by ``ix`` using ``self._kernel``.

    ``ix`` is iterated as a sequence of ``(i, j)`` pairs.  For each pair,
    ``data[i, j]`` is overwritten in place with the kernel evaluated on
    ``ubasis[j]`` and ``vbasis[i]``.  Nothing is returned; the caller reads
    the results from ``data``.
    """
    for row, col in ix:
        # first index selects the v (test) entry, second the u (trial) entry
        trial = ubasis[col]
        test = vbasis[row]
        data[row, col] = self._kernel(trial, test, wdict, dx)
8 changes: 6 additions & 2 deletions skfem/assembly/form/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ class Form:

def __init__(self,
form: Optional[Callable] = None,
dtype: type = np.float64):
dtype: type = np.float64,
nthreads: int = 0):
self.form = form.form if isinstance(form, Form) else form
self.dtype = dtype
self.nthreads = nthreads

def partial(self, *args, **kwargs):
form = deepcopy(self)
Expand All @@ -34,7 +36,9 @@ def partial(self, *args, **kwargs):

def __call__(self, *args):
if self.form is None: # decorate
return type(self)(form=args[0], dtype=self.dtype)
return type(self)(form=args[0],
dtype=self.dtype,
nthreads=self.nthreads)
return self.assemble(self.kernel(*args))

def assemble(self,
Expand Down
33 changes: 32 additions & 1 deletion tests/test_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import pytest
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal
from numpy.testing import (assert_equal, assert_almost_equal,
assert_array_almost_equal)

from skfem import BilinearForm, LinearForm, Functional, asm, solve
from skfem.element import (ElementQuad1, ElementQuadS2, ElementHex1,
Expand Down Expand Up @@ -35,6 +36,13 @@ def uv(u, v, w):

B = asm(uv, self.fbasis)

# assemble the same matrix using multiple threads
@BilinearForm(nthreads=2)
def uvt(u, v, w):
return u * v

Bt = asm(uvt, self.fbasis)

@LinearForm
def gv(v, w):
return 1.0 * v
Expand All @@ -45,6 +53,7 @@ def gv(v, w):

self.assertAlmostEqual(ones @ g, self.boundary_area, places=4)
self.assertAlmostEqual(ones @ (B @ ones), self.boundary_area, places=4)
self.assertAlmostEqual(ones @ (Bt @ ones), self.boundary_area, places=4)


class IntegrateOneOverBoundaryS2(IntegrateOneOverBoundaryQ1):
Expand Down Expand Up @@ -434,5 +443,27 @@ def complexfun(v, w):
self.assertAlmostEqual(np.dot(ones, f), 1j * self.interior_area)


class TestThreadedAssembly(TestCase):
    """Compare threaded and serial assembly of the same bilinear form."""

    def runTest(self):

        # ElementTriP1 basis on a refined default triangular mesh
        mesh = MeshTri().refined()
        element = ElementTriP1()
        basis = InteriorBasis(mesh, element)

        # serial reference form; not symmetric in u and v
        @BilinearForm
        def nonsym(u, v, w):
            return u.grad[0] * v

        # same integrand, assembled with two threads
        @BilinearForm(nthreads=2)
        def threaded_nonsym(u, v, w):
            return u.grad[0] * v

        serial_matrix = nonsym.assemble(basis).toarray()
        threaded_matrix = threaded_nonsym.assemble(basis).toarray()

        assert_almost_equal(serial_matrix, threaded_matrix)


if __name__ == '__main__':
main()