Skip to content

Commit

Permalink
relax the criterion on num_edges for using the "sparse" kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Mar 4, 2024
1 parent 8f1498b commit 5537539
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/pyjuice/layer/sum_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor,
batch_size = node_mars.size(1)
BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size)

assert num_edges <= 16384, "The sparse forward kernel only support nodes with # edges smaller than 16384."
# assert num_edges <= 16384, "The sparse forward kernel only support nodes with # edges smaller than 16384."

if triton.cdiv(layer_n_nodes, self.block_size) <= 2048:
BLOCK_B = max(min(2048 // num_edges, BATCH_SIZE_NP2), 1)
Expand Down Expand Up @@ -1083,8 +1083,10 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor,
elif self.block_size * batch_size < 32:
# Advantage of block-sparse processing is diminishing
mode = self.SPARSE
else:
elif num_edges <= 32768:
mode = self.BLOCK_SPARSE
else:
mode = self.SPARSE

if mode == self.BLOCK_SPARSE:
self._backward_block_sparse(
Expand Down Expand Up @@ -2320,8 +2322,7 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten
batch_size = node_mars.size(1)
BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size)

assert num_edges <= 16384, "The sparse backward kernel only support nodes with # edges smaller than 16384."

# assert num_edges <= 16384, "The sparse backward kernel only support nodes with # edges smaller than 16384."

if num_edges <= 1024:
BLOCK_B = max(min(2048 // num_edges, BATCH_SIZE_NP2), 1)
Expand Down

0 comments on commit 5537539

Please sign in to comment.