Fix logaddexp for ONNX export #1158

Merged: 1 commit merged into k2-fsa:master on Jul 2, 2023

Conversation

csukuangfj (Collaborator) commented:

See also #1157 (comment)

@danpovey
We should actually test whether torch.logaddexp works in jit tracing mode. I think the if-statement was probably added when we had some kind of custom kernel there that wouldn't work in tracing mode.


torch.logaddexp() works for both torch.jit.script() and torch.jit.trace(). However, it causes the following error for torch.onnx.export():

# RuntimeError: Exporting the operator logaddexp to ONNX opset version
# 14 is not supported. Please feel free to request support or submit
# a pull request on PyTorch GitHub.


This PR keeps torch.logaddexp() for torch.jit.script() and torch.jit.trace(), and falls back to an ONNX-exportable reformulation, max(x, y) + log1p(exp(-|x - y|)), for torch.onnx.export().
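
As a quick sanity check (illustrative only, not part of this PR), the reformulation agrees numerically with torch.logaddexp():

#!/usr/bin/env python3

import torch

x = torch.randn(1000)
y = torch.randn(1000)

expected = torch.logaddexp(x, y)
# ONNX-exportable reformulation built only from supported operators
actual = torch.max(x, y) + torch.log1p(torch.exp(-torch.abs(x - y)))

assert torch.allclose(expected, actual)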


The following code

#!/usr/bin/env python3

import torch


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        if torch.onnx.is_in_onnx_export():
            max_value = torch.max(x, y)
            diff = torch.abs(x - y)
            return max_value + torch.log1p(torch.exp(-diff))
        else:
            return torch.logaddexp(x, y)


def main():
    f = Foo()
    x = torch.rand(3)
    torch.jit.script(f)  # causes the error shown below
    torch.jit.trace(f, (x, x))
    torch.onnx.export(f, (x, x), "a.onnx")


if __name__ == "__main__":
    main()

throws the following error for torch.jit.script(). That is why we put torch.jit.is_scripting() as the first if statement in this pull request.

Traceback (most recent call last):
  File "./xxx.py", line 28, in <module>
    main()
  File "./xxx.py", line 22, in main
    torch.jit.script(f)  # cause errors
  File "/Users/fangjun/py38/lib/python3.8/site-packages/torch/jit/_script.py", line 1286, in script
    return torch.jit._recursive.create_script_module(
  File "/Users/fangjun/py38/lib/python3.8/site-packages/torch/jit/_recursive.py", line 476, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/Users/fangjun/py38/lib/python3.8/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/Users/fangjun/py38/lib/python3.8/site-packages/torch/jit/_recursive.py", line 393, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/Users/fangjun/py38/lib/python3.8/site-packages/torch/jit/_recursive.py", line 863, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/Users/fangjun/py38/lib/python3.8/site-packages/torch/jit/_script.py", line 1343, in script
    fn = torch._C._jit_script_compile(
RuntimeError:
attribute lookup is not defined on python value of type '_InternalGlobals':
  File "/Users/fangjun/py38/lib/python3.8/site-packages/torch/onnx/utils.py", line 69
def is_in_onnx_export() -> bool:
    """Returns whether it is in the middle of ONNX export."""
    return GLOBALS.in_onnx_export
           ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
'is_in_onnx_export' is being compiled since it was called from 'Foo.forward'
  File "xxx.py", line 11
    def forward(self, x, y):
        if torch.onnx.is_in_onnx_export():
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            max_value = torch.max(x, y)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~
            diff = torch.abs(x - y)
            ~~~~~~~~~~~~~~~~~~~~~~~
            return max_value + torch.log1p(torch.exp(-diff))
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        else:
        ~~~~~
            return torch.logaddexp(x, y)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
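
For reference, a minimal sketch of the resulting helper, with torch.jit.is_scripting() checked before torch.onnx.is_in_onnx_export() as described above (function names here are illustrative and may differ from the code actually merged):

import torch


def logaddexp_onnx(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Same result as torch.logaddexp(), but built only from operators
    # that torch.onnx.export() supports.
    max_value = torch.max(x, y)
    diff = torch.abs(x - y)
    return max_value + torch.log1p(torch.exp(-diff))


def logaddexp(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # torch.onnx.is_in_onnx_export() cannot be compiled by torch.jit.script(),
    # so the scripting check has to come first.
    if torch.jit.is_scripting():
        return torch.logaddexp(x, y)
    elif torch.onnx.is_in_onnx_export():
        return logaddexp_onnx(x, y)
    else:
        # torch.jit.trace() and eager mode
        return torch.logaddexp(x, y)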

@danpovey (Collaborator) left a comment:

cool!

@csukuangfj merged commit c3e23ec into k2-fsa:master on Jul 2, 2023
@csukuangfj deleted the fix-onnx-zipformer2 branch on July 28, 2023 at 02:38