Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save/load TorchScript object in test #1446

Merged
merged 2 commits into from
Apr 14, 2021

Conversation

mthrok
Copy link
Collaborator

@mthrok mthrok commented Apr 9, 2021

This PR adds step to save/load TorchScript object while testing TorchScript compatibility.

Part of #1337

Add dump step to torchscript tests
@anjali411
Copy link

@SplitInfinity can you take a look?

"""Implements test for `functinoal` modul that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype)

ts_func = torch.jit.script(func)
Copy link

@anjali411 anjali411 Apr 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we shouldn't remove this and continue checking scripted_fn(input) == fn(input) where scripted_fn = torch.jit.script(fn)
Synced with Meghan offline:
As mentioned in the comment below, torch.jit.load returns a scripted module, so the current test is ok, but I think it would be nice to move this code to a helper function that takes in a function/module and returns a scripted module, with a note (see the comment below) indicating what is exactly happening.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the above. To clarify, torch.jit.load returns a scripted module with a forward method equivalent to the function that was serialized. The caller does not notice the difference because __call__ for the module calls forward.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I have a mixed feeling about introducing another helper function.

First, this method IS the helper function used by the actual tests.
Secondly, my view is that if it's something that needs explanation with comment, keeping it along side of tests make it easier to follow the logic of tests.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment, you duplicate the same code logic at different places, which I think strongly suggests that we should just have one helper function that can be used in all the tests.

ts_func = torch.jit.script(func)
path = self.get_temp_path('func.zip')
torch.jit.script(func).save(path)
ts_func = torch.jit.load(path)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This returns a scripted module whose forward is the function you saved.

Copy link

@anjali411 anjali411 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR looks good! I still think we should add a helper function to avoid replicating the same code logic at different sites, but won't block this PR just for that.

@mthrok
Copy link
Collaborator Author

mthrok commented Apr 14, 2021

@anjali411 Thanks for the suggestion. Let me move on with the changes for now. The repeated code part, I have a different idea to resolve it.

@mthrok mthrok merged commit 5c696b5 into pytorch:master Apr 14, 2021
@mthrok mthrok deleted the test-dump-torchscript branch April 14, 2021 16:37
carolineechen pushed a commit to carolineechen/audio that referenced this pull request Apr 30, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants