Skip to content

Commit

Permalink
update threading related parameter names
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Dec 5, 2024
1 parent fb64a7d commit 6317362
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 69 deletions.
170 changes: 102 additions & 68 deletions quimb/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,40 +477,65 @@ def ispos(qob, tol=1e-15):


@njit(nogil=True)
def par_choose_nblocks(size_total, size_blocks, num_threads):
"""Give `size_total` items, a target block size `size_blocks`, and number of threads
`num_threads`, choose the number of blocks to split `size_total` into, the base block
size, and the remainder, for `par_get_block_range`.
def threading_choose_num_blocks(size_total, target_block_size, num_threads):
"""Given `size_total` items, `target_block_size`, and number of threads
`num_threads`, choose the number of blocks to split `size_total` into, the
base block size, and the remainder, used with `threading_get_block_range`.
Parameters
----------
size_total : int
Total number of items to split.
target_block_size : int
Target block size. If positive, blocks will be at least this size. If
negative, blocks will be close to this size.
num_threads : int
Number of threads to split into.
Returns
-------
int, int, int
Number of blocks, base block size, and block remainder.
"""
if num_threads == 1:
# always just 1 block for single thread
Nb = 1
elif size_blocks == -1:
# target num_threads blocks
Nb = num_threads
else:
# target blocks of size size_blocks
Nb = math.ceil(size_total / size_blocks)
if Nb > num_threads:
num_blocks = 1

elif target_block_size < 0:
# target blocks actually close to size target_block_size, for
# cyclically distributing work with potentially varying costs
target_block_size = -target_block_size
num_blocks = math.ceil(size_total / target_block_size)
if num_blocks > num_threads:
# round to nearest multiple of num_threads
Nb = num_threads * round(Nb / num_threads)
num_blocks = num_threads * round(num_blocks / num_threads)

else:
# target blocks at least as big as target_block_size
num_blocks = min(num_threads, round(size_total / num_threads))

base_block_size, remainder = divmod(size_total, Nb)
return Nb, base_block_size, remainder
base_block_size, block_remainder = divmod(size_total, num_blocks)
return num_blocks, base_block_size, block_remainder


@njit(nogil=True)
def par_get_block_range(b, base_block_size, remainder):
start = b * base_block_size + min(b, remainder)
block_size = base_block_size + (1 if b < remainder else 0)
def threading_get_block_range(b, base_block_size, block_remainder):
"""Given block index `b`, base block size `base_block_size`, and remainder
`block_remainder`, return the start and stop indices of the block.
"""
start = b * base_block_size + min(b, block_remainder)
block_size = base_block_size + (1 if b < block_remainder else 0)
stop = start + block_size
return start, stop


def par_maybe_multithread(
fn, *args, size_total, size_blocks, num_threads, **kwargs
def maybe_multithread(
fn, *args, size_total, target_block_size, num_threads, **kwargs
):
if size_total <= size_blocks:
"""Based on the size of the problem, either call `fn` directly or
get a pool and multithread it.
"""
if size_total <= target_block_size:
# don't bother getting pool
fn(*args, **kwargs)
else:
Expand All @@ -525,7 +550,7 @@ def par_maybe_multithread(
*args,
trank=trank,
num_threads=num_threads,
size_blocks=size_blocks,
target_block_size=target_block_size,
**kwargs,
)
for trank in range(num_threads)
Expand Down Expand Up @@ -638,20 +663,21 @@ def _dot_csr_matvec_numba(
out,
trank=0,
num_threads=1,
size_blocks=1024,
target_block_size=-1024,
):
N = vec.size

# total number of blocks
# this thread processes every num_threads'th block: the logic here is you want to
# process a large enough block of contiguous rows to make the memory access
# efficient, but also cyclically distribute the rows which may have varying
# sparsity on a larger scale
Nb, base_block_size, remainder = par_choose_nblocks(
N, size_blocks, num_threads
# this thread processes every num_threads'th block: the logic here is you
# want to process a large enough block of contiguous rows to make the
# memory access efficient, but also cyclically distribute the rows which
# may have varying sparsity on a larger scale
num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
N, target_block_size, num_threads
)
for b in range(trank, Nb, num_threads):
istart, istop = par_get_block_range(b, base_block_size, remainder)
for b in range(trank, num_blocks, num_threads):
istart, istop = threading_get_block_range(
b, base_block_size, block_remainder
)

for i in range(istart, istop):
isum = 0.0
Expand All @@ -660,7 +686,7 @@ def _dot_csr_matvec_numba(
out[i] = isum


def par_dot_csr_matvec(A, x, size_blocks=1024, num_threads=None):
def par_dot_csr_matvec(A, x, target_block_size=-1024, num_threads=None):
"""Parallel sparse csr-matrix vector dot product.
Parameters
Expand All @@ -669,7 +695,7 @@ def par_dot_csr_matvec(A, x, size_blocks=1024, num_threads=None):
Operator.
x : dense vector
Vector.
size_blocks : int, optional
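target_block_size : int, optional
The target block size (number of rows) if parallel. If negative (the
default), blocks are made close to this size in magnitude and are
distributed cyclically among the threads.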
num_threads : int, optional
Number of threads to use. If None, will use the default number of
Expand All @@ -687,15 +713,15 @@ def par_dot_csr_matvec(A, x, size_blocks=1024, num_threads=None):
"""
y = np.empty(x.size, common_type(A, x))

par_maybe_multithread(
maybe_multithread(
_dot_csr_matvec_numba,
A.data,
A.indptr,
A.indices,
x.ravel(),
y,
size_total=x.size,
size_blocks=size_blocks,
target_block_size=target_block_size,
num_threads=num_threads,
)

Expand Down Expand Up @@ -760,35 +786,37 @@ def rdot(a, b): # pragma: no cover

@njit(nogil=True)
def _l_diag_dot_dense_par(
l, A, out, trank=0, num_threads=1, size_blocks=128
l, A, out, trank=0, num_threads=1, target_block_size=128
): # pragma: no cover
N, M = A.shape
Nb, base_block_size, remainder = par_choose_nblocks(
N, size_blocks, num_threads
num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
N, target_block_size, num_threads
)
for b in range(trank, Nb, num_threads):
istart, istop = par_get_block_range(b, base_block_size, remainder)
for b in range(trank, num_blocks, num_threads):
istart, istop = threading_get_block_range(
b, base_block_size, block_remainder
)
for i in range(istart, istop):
li = l[i]
for j in range(M):
out[i, j] = li * A[i, j]


@ensure_qarray
def l_diag_dot_dense(diag, mat, num_threads=None, size_blocks=128):
def l_diag_dot_dense(diag, mat, num_threads=None, target_block_size=128):
"""Dot product of diagonal matrix (with only diagonal supplied) and dense
matrix.
"""
diag = diag.ravel()
out = np.empty_like(mat, dtype=common_type(diag, mat))

par_maybe_multithread(
maybe_multithread(
_l_diag_dot_dense_par,
diag,
mat,
out,
size_total=diag.size,
size_blocks=size_blocks,
target_block_size=target_block_size,
num_threads=num_threads,
)

Expand Down Expand Up @@ -825,33 +853,35 @@ def ldmul(diag, mat):

@njit(nogil=True)
def _r_diag_dot_dense_par(
A, l, out, trank=0, num_threads=1, size_blocks=128
A, l, out, trank=0, num_threads=1, target_block_size=128
): # pragma: no cover
N, M = A.shape
Nb, base_block_size, remainder = par_choose_nblocks(
N, size_blocks, num_threads
num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
N, target_block_size, num_threads
)
for b in range(trank, Nb, num_threads):
istart, istop = par_get_block_range(b, base_block_size, remainder)
for b in range(trank, num_blocks, num_threads):
istart, istop = threading_get_block_range(
b, base_block_size, block_remainder
)
for i in range(istart, istop):
for j in range(M):
out[i, j] = A[i, j] * l[j]


@ensure_qarray
def r_diag_dot_dense(mat, diag, num_threads=None, size_blocks=128):
def r_diag_dot_dense(mat, diag, num_threads=None, target_block_size=128):
"""Dot product of dense matrix and digonal matrix (with only diagonal
supplied).
"""
diag = diag.ravel()
out = np.empty_like(mat, dtype=common_type(diag, mat))
par_maybe_multithread(
maybe_multithread(
_r_diag_dot_dense_par,
mat,
diag,
out,
size_total=diag.size,
size_blocks=size_blocks,
target_block_size=target_block_size,
num_threads=num_threads,
)
return out
Expand Down Expand Up @@ -888,35 +918,37 @@ def rdmul(mat, diag):

@njit(nogil=True)
def _outer_par(
x, y, out, m, n, trank=0, num_threads=1, size_blocks=128
x, y, out, m, n, trank=0, num_threads=1, target_block_size=128
): # pragma: no cover
Nb, base_block_size, remainder = par_choose_nblocks(
m, size_blocks, num_threads
num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
m, target_block_size, num_threads
)
for b in range(trank, Nb, num_threads):
istart, istop = par_get_block_range(b, base_block_size, remainder)
for b in range(trank, num_blocks, num_threads):
istart, istop = threading_get_block_range(
b, base_block_size, block_remainder
)

for i in range(istart, istop):
for j in range(n):
out[i, j] = x[i] * y[j]


@ensure_qarray
def outer(a, b, num_threads=None, size_blocks=128):
def outer(a, b, num_threads=None, target_block_size=128):
"""Outer product between two vectors (no conjugation)."""
a = a.ravel()
b = b.ravel()
m, n = a.size, b.size
out = np.empty((m, n), dtype=common_type(a, b))
par_maybe_multithread(
maybe_multithread(
_outer_par,
a,
b,
out,
m,
n,
size_total=m,
size_blocks=size_blocks,
target_block_size=target_block_size,
num_threads=num_threads,
)
return out
Expand Down Expand Up @@ -944,14 +976,16 @@ def _kron_dense_numba(
q,
trank=0,
num_threads=1,
size_blocks=128,
target_block_size=128,
): # pragma: no cover
N = m * p
Nb, base_block_size, remainder = par_choose_nblocks(
N, size_blocks, num_threads
num_blocks, base_block_size, block_remainder = threading_choose_num_blocks(
N, target_block_size, num_threads
)
for b in range(trank, Nb, num_threads):
istart, istop = par_get_block_range(b, base_block_size, remainder)
for b in range(trank, num_blocks, num_threads):
istart, istop = threading_get_block_range(
b, base_block_size, block_remainder
)
for i in range(istart, istop):
ia, ib = divmod(i, p)
i = p * ia + ib
Expand All @@ -963,11 +997,11 @@ def _kron_dense_numba(


@ensure_qarray
def kron_dense(a, b, num_threads=None, size_blocks=128):
def kron_dense(a, b, num_threads=None, target_block_size=128):
m, n = a.shape
p, q = b.shape
out = np.empty((m * p, n * q), dtype=common_type(a, b))
par_maybe_multithread(
maybe_multithread(
_kron_dense_numba,
a,
b,
Expand All @@ -977,7 +1011,7 @@ def kron_dense(a, b, num_threads=None, size_blocks=128):
p,
q,
size_total=m * p,
size_blocks=size_blocks,
target_block_size=target_block_size,
num_threads=num_threads,
)
return out
Expand Down
2 changes: 1 addition & 1 deletion tests/test_accel.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def test_small(self):
class TestKron:
@mark.parametrize("big", [False, True])
def test_kron_dense(self, mat_d, mat_d2, big):
x = kron_dense(mat_d, mat_d2, size_blocks=1 if big else _TEST_SZ)
x = kron_dense(mat_d, mat_d2, target_block_size=1 if big else _TEST_SZ)
assert mat_d.shape == (_TEST_SZ, _TEST_SZ)
assert mat_d2.shape == (_TEST_SZ, _TEST_SZ)
xn = np.kron(mat_d, mat_d2)
Expand Down

0 comments on commit 6317362

Please sign in to comment.