Will maskrcnn-benchmark support torch.jit.trace or torch.jit.script mode in the near future? #27
There is currently some Python functionality in this codebase that is not supported by the JIT. Currently, you can trace almost all of the model, except the custom C++ layers. Once we add support for those missing C++ layers by registering them as torch ops, I believe tracing should work without issues for same-sized images. I'll look into registering the C++ layers as torch ops.
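As a sketch of what fixed-size tracing looks like for standard layers (a toy model, not maskrcnn-benchmark itself; the module and names here are purely illustrative):

```python
import torch
import torch.nn as nn

# Toy stand-in for a backbone; the real model additionally needs the
# custom C++ layers (nms, roi_align) registered as torch ops before
# it can be traced end to end.
class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))

model = TinyBackbone().eval()
example = torch.rand(1, 3, 32, 32)   # fixed input size, as assumed when tracing
traced = torch.jit.trace(model, example)

# The traced module is callable like the original and could be serialized
# with traced.save("model.pt") for loading from C++.
out = traced(torch.rand(1, 3, 32, 32))
```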
@fmassa the C++ layers can be registered with the JIT, see @goldsborough's slides from DevCon
Yes, but I believe that it currently requires some extra code that follows a different codepath.
@fmassa Thanks for your help. If there is a way to trace the Mask R-CNN model, or any useful information, please let us know.
I'll look into adding tracing support for the custom ops early this week; it should not be hard. I'll update the issue once it's done.
Awesome!
I am also interested in this feature!
@fmassa Have you added tracing support for the custom ops? Thanks.
So I did look at this in some depth and continue to do so; here is a bit of a progress report for discussion. I'm also happy to share a branch with my code, but the code is even more "stream of consciousness" than this write-up.

Goal and plan

My goal is to be able to run detection on single images of a fixed size (known during tracing) in C++, as close as possible to the "load traced model in C++" example. My first step is to get something that works under tracing.
My findings so far

C++ bits / Custom Ops

```cpp
#include <torch/script.h>
...
static auto registry =
    torch::jit::RegisterOperators()
        .op("maskrcnn_benchmark::nms", &nms)
        .op("maskrcnn_benchmark::roi_align_forward(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor", &ROIAlign_forward);
```
```python
import torch
nms = torch.ops.maskrcnn_benchmark.nms
```

Easy. However, I could not trace the resulting nms directly. Wrapping it in script functions that bind the non-tensor arguments as defaults works:

```python
@torch.jit.script
def nms_fixed_thresh1(dets, scores, th: float=coco_demo.model.rpn.box_selector_test.nms_thresh):
    return maskrcnn_benchmark.layers.nms(dets, scores, th)

@torch.jit.script
def nms_fixed_thresh(dets, scores):
    return nms_fixed_thresh1(dets, scores)
```

Now we can trace nms_fixed_thresh. A similar wrapping trick was needed for roi_align_forward; I put that in the layer (where all the constants are parameters, so it's natural). I did change some lists to tuples to make the JIT happier.

Custom bookkeeping types (boxlist oh oh)

The JIT isn't very fond of the boxlist things. Where it works, a minimal fix is to "unpack" the parameters of functions, assuming that all Tensors are arguments and all others are constants. That works reasonably well when operating on the same input again in traced mode. It remains to be seen if we run into generalization problems. To facilitate that, I added two methods to bounding_box (note: _get_tensors/_set_tensors only work if the keys don't change in between!):
```python
def _get_tensors(self):
    return (self.bbox,) + tuple(f for f in (self.get_field(field) for field in sorted(self.fields()))
                                if isinstance(f, torch.Tensor))

def _set_tensors(self, ts):
    self.bbox = ts[0]
    i = 1  # separate counter: non-tensor fields must not shift the indices
    for f in sorted(self.fields()):
        if isinstance(self.extra_fields[f], torch.Tensor):
            self.extra_fields[f] = ts[i]
            i += 1
```

and there is some wrapper code.

Some things that don't work well with tracing/scripting

The box decoding builds its output with assignments to strided slices:

```python
pred_boxes = torch.zeros_like(rel_codes)
# x1
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
# y1
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
# x2 (note: "- 1" is correct; don't be fooled by the asymmetry)
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1
# y2 (note: "- 1" is correct; don't be fooled by the asymmetry)
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1
```

The JIT doesn't love that, so I used:

```python
pred_boxes = torch.stack([pred_ctr_x - 0.5 * pred_w,
                          pred_ctr_y - 0.5 * pred_h,
                          pred_ctr_x + 0.5 * pred_w - 1,
                          pred_ctr_y + 0.5 * pred_h - 1], 2).view(*rel_codes.shape)
```

Similarly, the pooling over several levels needed a rework.
There still is a problem (a riddle?) around script wanting to pass a tensor list as a list of tensors and not a tuple, while tracing does not accept lists; I will have to sort that out. Maybe one could convince the JIT people to allow passing tuples of tensors where the JIT wants lists of tensors.

Mask composition

This uses PIL at the moment; it'll be replaced. @fmassa has this for GPU, but I will do a CPU version and a custom op for it.

Things that work at the moment

So I'm now at the roi heads (as you can see above), the box head first.
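The _get_tensors/_set_tensors unpacking described above can be sketched standalone. Here is a toy version (MiniBoxList is a hypothetical, heavily simplified stand-in for BoxList, with NumPy arrays playing the tensors); note the separate counter in _set_tensors, so non-tensor fields don't shift the indices:

```python
import numpy as np

class MiniBoxList:
    """Toy stand-in for maskrcnn_benchmark's BoxList (hypothetical, simplified)."""
    def __init__(self, bbox):
        self.bbox = bbox
        self.extra_fields = {}

    def add_field(self, name, value):
        self.extra_fields[name] = value

    def fields(self):
        return list(self.extra_fields.keys())

    def _get_tensors(self):
        # Only "tensor" (here: ndarray) fields travel through tracing.
        return (self.bbox,) + tuple(
            self.extra_fields[f] for f in sorted(self.fields())
            if isinstance(self.extra_fields[f], np.ndarray))

    def _set_tensors(self, ts):
        self.bbox = ts[0]
        j = 1  # separate counter so non-tensor fields don't shift indices
        for f in sorted(self.fields()):
            if isinstance(self.extra_fields[f], np.ndarray):
                self.extra_fields[f] = ts[j]
                j += 1

b = MiniBoxList(np.zeros((2, 4)))
b.add_field("scores", np.ones(2))
b.add_field("mode", "xyxy")            # non-tensor field: treated as a constant
ts = b._get_tensors()                  # (bbox, scores) -- mode is skipped
b._set_tensors(tuple(t + 1 for t in ts))
```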
That's awesome progress @t-vi! We were aware of some of these problems. Indexing with the JIT doesn't work very well yet (but we are improving support for it), so the approach you followed there is a good one. I didn't quite understand the problem with tracing the constant parameters, but I suppose this is a bug in upstream PyTorch? Thanks a lot for all your help!
So I filed the two JIT observations as issues with PyTorch (see above).
It seems that pytorch/pytorch#13564 has been fixed.
Yes, and we managed to do tracing in #138. There is a "regression" in 1.0 that invalidates the merge_levels script, so you'd currently need to replace it with a (very straightforward) custom op.
@t-vi, I hit a core dump when I executed trace_model.py from your patch. The core information is below.
Hello, any progress on this? I am also very interested.
Hi, any progress on this? Has anyone managed to do it?
I'm also interested in knowing about the progress on this.
❓ Questions and Help
As we know, in PyTorch 1.0, Torch Script is a way to create serializable and optimizable models from PyTorch code. Any code written in Torch Script can be saved from a Python process and loaded in a process where there is no Python dependency.
So, will maskrcnn-benchmark support torch.jit.trace or torch.jit.script mode in the near future?