-
Notifications
You must be signed in to change notification settings - Fork 667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Save/load TorchScript object in test #1446
Conversation
Add dump step to torchscript tests
@SplitInfinity can you take a look? |
"""Implements test for `functinoal` modul that are performed for different devices""" | ||
def _assert_consistency(self, func, tensor, shape_only=False): | ||
tensor = tensor.to(device=self.device, dtype=self.dtype) | ||
|
||
ts_func = torch.jit.script(func) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we shouldn't remove this and continue checking scripted_fn(input) == fn(input)
where scripted_fn = torch.jit.script(fn)
Synced with Meghan offline:
As mentioned in the comment below, torch.jit.load
returns a scripted module, so the current test is ok, but I think it would be nice to move this code to a helper function that takes in a function/module and returns a scripted module, with a note (see the comment below) indicating what is exactly happening.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with the above. To clarify, torch.jit.load
returns a scripted module with a forward
method equivalent to the function that was serialized. The caller does not notice the difference because __call__
for the module calls forward
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I have a mixed feeling about introducing another helper function.
First, this method IS the helper function used by the actual tests.
Secondly, my view is that if it's something that needs explanation with comment, keeping it along side of tests make it easier to follow the logic of tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At the moment, you duplicate the same code logic at different places, which I think strongly suggests that we should just have one helper function that can be used in all the tests.
ts_func = torch.jit.script(func) | ||
path = self.get_temp_path('func.zip') | ||
torch.jit.script(func).save(path) | ||
ts_func = torch.jit.load(path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This returns a scripted module whose forward is the function you saved.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR looks good! I still think we should add a helper function to avoid replicating the same code logic at different sites, but won't block this PR just for that.
@anjali411 Thanks for the suggestion. Let me move on with the changes for now. The repeated code part, I have a different idea to resolve it. |
This PR adds step to save/load TorchScript object while testing TorchScript compatibility.
Part of #1337