sum layer large sparse forward: handle cases with very high grid size
liuanji committed Jan 14, 2024
1 parent 4f84613 commit 745c6a3
Showing 1 changed file with 43 additions and 18 deletions.
src/pyjuice/layer/sum_layer.py — 43 additions, 18 deletions (61 changed lines)
@@ -858,12 +858,12 @@ def _fw_triton_sparse_kernel(node_mars, element_mars, params, nids, cids, pids,
     # @triton.jit
     @FastJITFunction
     def _fw_triton_large_sparse_kernel(node_mars, element_mars, params, nids, cids, pids,
-                                       local_ids, batch_size, num_nodes, partial_eval: tl.constexpr, num_edges: tl.constexpr,
+                                       local_ids, batch_size, num_nodes, pid_m_offset, partial_eval: tl.constexpr, num_edges: tl.constexpr,
                                        BLOCK_B: tl.constexpr, TILE_SIZE_M: tl.constexpr,
                                        GROUP_SIZE_M: tl.constexpr):
 
         pid_b = tl.program_id(axis = 0) # ID of size-`BLOCK_B` batches
-        pid_m = tl.program_id(axis = 1) # ID of size-`TILE_SIZE_M` nodes
+        pid_m = tl.program_id(axis = 1) + pid_m_offset # ID of size-`TILE_SIZE_M` nodes
 
         offs_m = tl.arange(0, TILE_SIZE_M) + pid_m * TILE_SIZE_M
         mask_m = offs_m < num_nodes
@@ -961,22 +961,47 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor,
 
         grid = (triton.cdiv(batch_size, BLOCK_B), triton.cdiv(layer_n_nodes, TILE_SIZE_M))
 
-        self._fw_triton_large_sparse_kernel[grid](
-            node_mars = node_mars,
-            element_mars = element_mars,
-            params = params,
-            nids = nids,
-            cids = cids,
-            pids = pids,
-            local_ids = local_ids,
-            batch_size = batch_size,
-            num_nodes = layer_n_nodes,
-            partial_eval = partial_eval,
-            num_edges = num_edges,
-            BLOCK_B = BLOCK_B,
-            TILE_SIZE_M = TILE_SIZE_M,
-            GROUP_SIZE_M = GROUP_SIZE_M
-        )
+        if grid[1] <= 32768:
+            self._fw_triton_large_sparse_kernel[grid](
+                node_mars = node_mars,
+                element_mars = element_mars,
+                params = params,
+                nids = nids,
+                cids = cids,
+                pids = pids,
+                local_ids = local_ids,
+                batch_size = batch_size,
+                num_nodes = layer_n_nodes,
+                pid_m_offset = 0,
+                partial_eval = partial_eval,
+                num_edges = num_edges,
+                BLOCK_B = BLOCK_B,
+                TILE_SIZE_M = TILE_SIZE_M,
+                GROUP_SIZE_M = GROUP_SIZE_M
+            )
+        else:
+            for pid_m_start in range(0, grid[1], 32768):
+
+                pid_m_end = min(pid_m_start + 32768, grid[1])
+                small_grid = (grid[0], pid_m_end - pid_m_start)
+
+                self._fw_triton_large_sparse_kernel[small_grid](
+                    node_mars = node_mars,
+                    element_mars = element_mars,
+                    params = params,
+                    nids = nids,
+                    cids = cids,
+                    pids = pids,
+                    local_ids = local_ids,
+                    batch_size = batch_size,
+                    num_nodes = layer_n_nodes,
+                    pid_m_offset = pid_m_start,
+                    partial_eval = partial_eval,
+                    num_edges = num_edges,
+                    BLOCK_B = BLOCK_B,
+                    TILE_SIZE_M = TILE_SIZE_M,
+                    GROUP_SIZE_M = GROUP_SIZE_M
+                )
 
         return None

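Note on the pattern (not part of the commit): CUDA caps the second and third grid dimensions at 65535 blocks, so a launch whose axis-1 extent exceeds that limit fails. The change above splits the axis-1 grid into sub-launches of at most 32768 programs and passes each chunk's starting tile index into the kernel as pid_m_offset, so every program can reconstruct its global tile id. Below is a minimal standalone sketch of that pattern; the kernel, helper, and constant names (_offset_kernel, launch_in_chunks, MAX_GRID_M) are illustrative, not from the repository.

import torch
import triton
import triton.language as tl

MAX_GRID_M = 32768  # chunk size for the axis-1 grid, kept well under the 65535 hardware cap

@triton.jit
def _offset_kernel(x_ptr, n_elems, pid_m_offset, TILE: tl.constexpr):
    # Global tile id = local program id within this sub-launch + chunk offset.
    pid_m = tl.program_id(axis = 1) + pid_m_offset
    offs = pid_m * TILE + tl.arange(0, TILE)
    mask = offs < n_elems
    x = tl.load(x_ptr + offs, mask = mask, other = 0.0)
    tl.store(x_ptr + offs, x + 1.0, mask = mask)

def launch_in_chunks(x, TILE = 128):
    n_tiles = triton.cdiv(x.numel(), TILE)
    # Split the axis-1 grid into sub-launches of at most MAX_GRID_M programs each.
    for pid_m_start in range(0, n_tiles, MAX_GRID_M):
        sub_grid = (1, min(MAX_GRID_M, n_tiles - pid_m_start))
        _offset_kernel[sub_grid](x, x.numel(), pid_m_start, TILE = TILE)

# Usage: x = torch.zeros(16_000_000, device = "cuda"); launch_in_chunks(x)

The commit's 32768 chunk size appears to be a conservative power of two below the hard limit; any sub-launch extent not exceeding 65535 along axis 1 would satisfy the hardware constraint.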
