diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 0a60e39b..6f1d8701 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -976,7 +976,7 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, elif params.dim() == 1 and self.group_size >= 16 and num_edges >= 16 and batch_size >= 16: # In this case, we should definitely use the block-sparse implementation mode = self.BLOCK_SPARSE - elif (cs_group_size == 1 or self.group_size == 1) and num_edges < 16384: + elif (cs_group_size == 1 or self.group_size == 1) and num_edges <= 32768: # In this case, we should definitely use the sparse implementation mode = self.SPARSE elif self.group_size * batch_size < 32: