Skip to content

Commit

Permalink
Merge pull request #167 from magnatelee/cholesky-heuristic
Browse files Browse the repository at this point in the history
Simple tiling heuristic for Cholesky factorization
  • Loading branch information
magnatelee authored Jan 7, 2022
2 parents 97b6c7c + 1a15006 commit 2199f8f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
3 changes: 2 additions & 1 deletion cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,4 +1752,5 @@ def compute_strides(shape):
@auto_convert([1])
@shadow_debug("cholesky", [1])
def cholesky(self, src, stacklevel=0, callsite=None):
cholesky(self, src, stacklevel, callsite)
cholesky(self, src, stacklevel=stacklevel + 1, callsite=callsite)
self.trilu(self, 0, True, stacklevel=stacklevel + 1, callsite=callsite)
34 changes: 30 additions & 4 deletions cunumeric/linalg/cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,38 @@ def gemm(context, p_output, k, i, lo, hi):
task.execute()


# Default per-tile extent threshold: the number of tiles keeps doubling
# only while the (ceil-divided) tile extent is strictly above this value.
MIN_CHOLESKY_TILE_SIZE = 2048
# Default matrix extent at or below which the matrix is left unpartitioned.
MIN_CHOLESKY_MATRIX_SIZE = 8192


# TODO: We need a better cost model
def choose_color_shape(
    runtime, shape, *, min_tile_size=None, min_matrix_size=None
):
    """Pick the square tile grid used to partition a Cholesky input.

    Parameters
    ----------
    runtime
        Object exposing ``test_mode`` (truthy/falsy) and ``num_procs``
        (an integer processor count).
    shape
        Shape of the matrix; only ``shape[0]`` is consulted, and only
        outside test mode.
    min_tile_size : int, optional
        Tile-extent threshold for the doubling loop. Defaults to
        ``MIN_CHOLESKY_TILE_SIZE``.
    min_matrix_size : int, optional
        Extent at or below which no partitioning happens. Defaults to
        ``MIN_CHOLESKY_MATRIX_SIZE``.

    Returns
    -------
    tuple of int
        ``(num_tiles, num_tiles)``.

    Notes
    -----
    The search starts at ``num_procs`` tiles per dimension and doubles
    while the tile extent is above ``min_tile_size``, never exceeding
    ``4 * num_procs``. The threshold only gates further doubling: the
    tile extent of the returned grid may be smaller than
    ``min_tile_size``.
    """
    # Resolve defaults at call time so that the module-level constants
    # remain the single source of truth for the default thresholds.
    if min_tile_size is None:
        min_tile_size = MIN_CHOLESKY_TILE_SIZE
    if min_matrix_size is None:
        min_matrix_size = MIN_CHOLESKY_MATRIX_SIZE

    num_procs = runtime.num_procs

    # Test mode always uses a fixed grid of twice the processor count,
    # regardless of the matrix size.
    if runtime.test_mode:
        num_tiles = num_procs * 2
        return (num_tiles, num_tiles)

    extent = shape[0]

    # If there's only one processor or the matrix is too small,
    # don't even bother to partition it at all
    if num_procs == 1 or extent <= min_matrix_size:
        return (1, 1)

    # Otherwise refine the grid: double the tile count while the tiles
    # are still larger than the threshold and the cap allows it.
    num_tiles = num_procs
    max_num_tiles = num_procs * 4
    while (
        (extent + num_tiles - 1) // num_tiles > min_tile_size
        and num_tiles * 2 <= max_num_tiles
    ):
        num_tiles *= 2

    return (num_tiles, num_tiles)


def cholesky(output, input, stacklevel=0, callsite=None):
num_procs = output.runtime.num_procs * 2
shape = output.base.shape
color_shape = (num_procs, num_procs)
color_shape = choose_color_shape(output.runtime, shape)
tile_shape = (shape + color_shape - 1) // color_shape
color_shape = (shape + tile_shape - 1) // tile_shape
n = color_shape[0]
Expand All @@ -113,5 +141,3 @@ def cholesky(output, input, stacklevel=0, callsite=None):
for k in range(i + 1, n):
syrk(context, p_output, k, i)
gemm(context, p_output, k, i, k + 1, n)

output.trilu(output, 0, True, stacklevel=stacklevel + 1, callsite=callsite)

0 comments on commit 2199f8f

Please sign in to comment.