Skip to content

Commit

Permalink
Merge pull request #77 from kitsuyui/update-jit-tests
Browse files Browse the repository at this point in the history
chore: update tests about torch.jit
  • Loading branch information
kitsuyui authored Mar 21, 2023
2 parents 0c7cba9 + b15b8a5 commit 030072d
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions example/torch/test_torch_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@
# https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html


def something_orig(x: torch.Tensor) -> torch.Tensor:
def something(x: torch.Tensor) -> torch.Tensor:
for i in range(512):
x += i
return x


@torch.jit.script
def something_jit(x: torch.Tensor) -> torch.Tensor:
for i in range(512):
x += i
return x
something_jit = torch.jit.script(something)


class SomethingOrig(nn.Module):
Expand All @@ -28,12 +24,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

def test_torch_jit_fn() -> None:
x = torch.rand(1)
assert something_orig(x) == something_jit(x)
orig_time = timeit.timeit(lambda: something_orig(x), number=1000)
assert something(x) == something_jit(x)
orig_time = timeit.timeit(lambda: something(x), number=1000)
jit_time = timeit.timeit(lambda: something_jit(x), number=1000)

tobe_ir = """\
def something_jit(x: Tensor) -> Tensor:
def something(x: Tensor) -> Tensor:
x0 = x
for i in range(512):
x0 = torch.add_(x0, i)
Expand Down

0 comments on commit 030072d

Please sign in to comment.