diff --git a/src/pyjuice/utils/kernel_launcher.py b/src/pyjuice/utils/kernel_launcher.py index 5b313eb5..6690c890 100644 --- a/src/pyjuice/utils/kernel_launcher.py +++ b/src/pyjuice/utils/kernel_launcher.py @@ -34,6 +34,13 @@ def wrapper(*args, **kwargs): if k in self.constexpr_names: signature_list.append((self.constexpr_names[k], v)) + if "batch_size" in kwargs: + signature_list.append(("batch_size", kwargs["batch_size"])) + + for i, arg in enumerate(args): + if isinstance(arg, torch.Tensor): + signature_list.append((i, id(arg))) + grid0 = grid[0] grid1 = grid[1] if len(grid) > 1 else 1 grid2 = grid[2] if len(grid) > 2 else 1