From bb9aca62de7e9885bcc52411411694d4e2b8025c Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sat, 13 Jan 2024 05:52:55 +0800 Subject: [PATCH] fix backward kernel selection --- src/pyjuice/layer/sum_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 0a60e39b..6f1d8701 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -976,7 +976,7 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, elif params.dim() == 1 and self.group_size >= 16 and num_edges >= 16 and batch_size >= 16: # In this case, we should definitely use the block-sparse implementation mode = self.BLOCK_SPARSE - elif (cs_group_size == 1 or self.group_size == 1) and num_edges < 16384: + elif (cs_group_size == 1 or self.group_size == 1) and num_edges <= 32768: # In this case, we should definitely use the sparse implementation mode = self.SPARSE elif self.group_size * batch_size < 32: