[JIT] Cannot trace custom ops #13564

Closed
t-vi opened this issue Nov 5, 2018 · 0 comments
Labels
oncall: jit Add this issue/PR to JIT oncall triage queue

t-vi commented Nov 5, 2018

🐛 Bug

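Custom ops registered with `torch::jit::RegisterOperators` can be called from TorchScript (`@torch.jit.script`), but tracing them directly with `torch.jit.trace` does not work.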

To Reproduce

import os
import torch
import torch.jit
csrc = """
#include <torch/extension.h>
#include <torch/script.h>

using namespace at;

Tensor test(const Tensor& inp) {
  return inp * 2;
}

static auto registry =
  torch::jit::RegisterOperators()
    .op("mytest::test", &test);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("test", &test, "super test!");
}
"""

import torch.utils.cpp_extension

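# Link against libtorch/libcaffe2, which ship in torch/lib next to torch._C
lib_dir = os.path.join(os.path.dirname(torch._C.__file__), 'lib')
ext = torch.utils.cpp_extension.load_inline("test", [csrc], verbose=True,
                                            extra_ldflags=['-ltorch', '-lcaffe2', '-L' + lib_dir])
# Load the compiled library so its registered ops are reachable via torch.ops
torch.ops.load_library(ext.__file__)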

t = torch.randn(5)
print(torch.ops.mytest.test(t))  # works

@torch.jit.script
def test_wrapper(t):
    return torch.ops.mytest.test(t)

print(torch.jit.trace(test_wrapper, (t,)))  # works, too

print(torch.jit.trace(torch.ops.mytest.test, (t,)))  # fails, but should work!
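Registering `test` with `torch::jit::RegisterOperators().op("mytest::test", &test)` makes it callable by the JIT, which passes operator inputs and outputs on a stack (the commit message in this thread refers to `pop(stack, ...)`). As a rough illustration only, here is a toy Python model of that boxed calling convention, assuming plain Python values instead of tensors; `REGISTRY`, `register_op`, `call_op`, and `double_kernel` are hypothetical names, not PyTorch APIs.

```python
# Toy model, not PyTorch's API: operators registered with the JIT are
# "boxed", i.e. they read their inputs from a shared stack and push results.
from typing import Any, Callable, Dict, List

Stack = List[Any]

# Hypothetical registry keyed by qualified name, e.g. "mytest::test".
REGISTRY: Dict[str, Callable[[Stack], None]] = {}


def register_op(name: str, fn: Callable[..., Any], num_args: int) -> None:
    """Wrap an ordinary (unboxed) function so it can be called via a stack."""

    def boxed(stack: Stack) -> None:
        start = len(stack) - num_args
        args = stack[start:]     # the inputs sit on top of the stack
        del stack[start:]        # pop them
        stack.append(fn(*args))  # push the single result

    REGISTRY[name] = boxed


def call_op(name: str, *args: Any) -> Any:
    """Push the inputs, run the boxed op, and pop its result."""
    stack: Stack = list(args)
    REGISTRY[name](stack)
    return stack.pop()


def double_kernel(xs: List[float]) -> List[float]:
    """Hypothetical stand-in for the C++ `test` kernel (`inp * 2`)."""
    return [2 * x for x in xs]


register_op("mytest::test", double_kernel, num_args=1)
print(call_op("mytest::test", [1.0, 2.0]))  # [2.0, 4.0]
```

In this model, keeping every op behind the same `stack -> None` signature is what lets an interpreter (or a tracer) invoke any registered op uniformly, regardless of its arity.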

Expected behavior

All three calls should succeed, including tracing the custom op directly with `torch.jit.trace(torch.ops.mytest.test, (t,))`.

Environment

PyTorch master as of Nov 5, 2018

Context

I ran into this while working on a formidable JIT challenge (@fmassa):
facebookresearch/maskrcnn-benchmark#27 (comment)

@goldsborough: This is the issue I mentioned on Slack.

pytorchbot added the oncall: jit label Nov 5, 2018
goldsborough self-assigned this Nov 5, 2018
driazati pushed a commit to driazati/pytorch that referenced this issue Nov 7, 2018
Summary:
Due to a logic bug, tracing is broken for custom ops. Unfortunately, there also weren't any tests for tracing custom ops.

The fix is a single-line change: moving `pop(stack, std::get<Is>(arguments)...);` before `node = getTracedNode<Is...>(schema, arguments);`. The other changes add tests and improve comments and formatting.

Fixes pytorch#13564

CC fmassa, zdevito
Pull Request resolved: pytorch#13654

Differential Revision: D12952887

Pulled By: goldsborough

fbshipit-source-id: 87d256576f787c58e8d8f5c13a0fecd0ec62a602
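The commit summary says the fix moves `pop(stack, std::get<Is>(arguments)...)` ahead of `getTracedNode<Is...>(schema, arguments)`. Read together, that suggests the traced node was being built from the `arguments` tuple before it had been filled from the stack. The sketch below is a toy Python model of that ordering, assuming a simplified tracer that only remembers the inputs it is handed; `Tracer`, `run_op`, `double`, and the `pop_first` flag are hypothetical names, not PyTorch APIs.

```python
# Toy model of the ordering bug; all names are hypothetical and only mimic
# the shape of the C++ operator wrapper, not PyTorch's real code.
from typing import Any, Callable, List, Optional, Tuple


class Tracer:
    """Records one (op name, inputs) node per traced operator call."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, Tuple[Any, ...]]] = []

    def record(self, op_name: str, inputs: Tuple[Any, ...]) -> None:
        # A real tracer maps each input to a value in the graph; here we
        # only remember what it was given.
        self.nodes.append((op_name, inputs))


def run_op(op_name: str, stack: List[Any], num_args: int,
           fn: Callable[..., Any], tracer: Optional[Tracer],
           pop_first: bool) -> None:
    # Placeholder arguments, like a default-constructed std::tuple.
    arguments: List[Any] = [None] * num_args

    def pop_args() -> None:
        start = len(stack) - num_args
        arguments[:] = stack[start:]  # fill arguments from the stack top
        del stack[start:]

    if pop_first:            # fixed order: fill arguments, then trace
        pop_args()
    if tracer is not None:   # the traced node captures the current arguments
        tracer.record(op_name, tuple(arguments))
    if not pop_first:        # buggy order: arguments were traced too early
        pop_args()
    stack.append(fn(*arguments))


def double(x: Any) -> Any:
    return 2 * x


for pop_first in (False, True):
    tracer = Tracer()
    stack = [21]
    run_op("mytest::test", stack, 1, double, tracer, pop_first)
    print(pop_first, stack, tracer.nodes)
# False [42] [('mytest::test', (None,))]
# True [42] [('mytest::test', (21,))]
```

Both orders compute the same result, which is consistent with scripting and eager calls working in the repro; only the recorded node differs, so the bug shows up when tracing.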