Skip to content

Commit

Permalink
check tensor address in the custom kernel launcher
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Jan 9, 2024
1 parent 2a3082a commit ad50984
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/pyjuice/utils/kernel_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def wrapper(*args, **kwargs):
if k in self.constexpr_names:
signature_list.append((self.constexpr_names[k], v))

if "batch_size" in kwargs:
signature_list.append(("batch_size", kwargs["batch_size"]))

for i, arg in enumerate(args):
if isinstance(arg, torch.Tensor):
signature_list.append((i, id(arg)))

grid0 = grid[0]
grid1 = grid[1] if len(grid) > 1 else 1
grid2 = grid[2] if len(grid) > 2 else 1
Expand Down

0 comments on commit ad50984

Please sign in to comment.