Add several op mapping in PyTorch frontend #6472

Merged: 7 commits into apache:master on Sep 18, 2020

Conversation

@yongwww (Member) commented Sep 15, 2020

@kevinthesun @zhiics @masahi please help take a look.

@masahi (Member) commented Sep 15, 2020

So we simply ignore the "inplace-ness" of copy_, is that right?

cc @t-vi I remember we had a discussion of copy_ semantics before.

@yongwww yongwww changed the title from "Add copy_ and clamp_ in PyTorch frontend" to "Add several op mapping in PyTorch frontend" on Sep 15, 2020
@yongwww (Member, Author) commented Sep 15, 2020

@masahi we found that copy_ exists in the Mask R-CNN model exported from https://github.com/facebookresearch/maskrcnn-benchmark; the mapping in this PR is enough to support that model and the test cases I prepared here.

I have added a test case class CopyInPlace(Module) to cover this in-place copy, and the test cases pass. The reason is that the JIT graph does not point to the original tensor after torch.Tensor.copy_. For example, in the following code snippet, a is used after copy_:

a = torch.rand((2, 3))
b = torch.rand((2, ))
c = torch.Tensor.copy_(a, b)
return a # a is supposed to be the original tensor, but in the jit graph it is the same node as c

related graph:

 graph(%self : __torch__.___torch_mangle_10.CopyInPlace,
      %a : Float(2, 3, 5),
      %b : Float(2, 3, 5)):
  %3 : bool = prim::Constant[value=0]()
  %4 : Float(2, 3, 5) = aten::copy_(%a, %b, %3)
  return (%4)
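
As a side note (not part of this PR), a tiny eager-mode script illustrating why the graph can simply return the op's output: copy_ writes into its destination tensor and returns that same tensor.

```
import torch

a = torch.zeros(2, 3)
b = torch.rand(2, 3)
c = torch.Tensor.copy_(a, b)   # equivalent to a.copy_(b): writes b's values into a and returns a

assert c.data_ptr() == a.data_ptr()   # the "output" aliases the destination tensor
assert torch.equal(a, b)              # a now holds b's values
```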

@masahi (Member) commented Sep 15, 2020

Yeah, that is also what happens with other inplace ops. The TorchScript graph has distinct input and output tensors for each op, even though for an inplace op the output and one of the inputs point to the same memory.

@yongwww (Member, Author) commented Sep 15, 2020

Yeah, that is also what happens with other inplace ops. The TorchScript graph has distinct input and output tensors for each op, even though for an inplace op the output and one of the inputs point to the same memory.

I see. Interesting, it is more like a TorchScript issue then.

@t-vi (Contributor) commented Sep 15, 2020

Note that maskrcnn-benchmark isn't necessarily intended to be scripted; the worthwhile things have been incorporated into torchvision and improved there.

The trouble with in-place operators is that their semantics are not functional (i.e. they modify their inputs) and thus cannot be mapped to TVM directly. In particular, when they operate on views (e.g. slices of a tensor, or anything that comes from view), they change the underlying tensor (a[2:4] = b is a slice plus a copy_ that modifies a).

What you would need to do is preprocess the graph to remove these side effects: for example, if you can rule out that the input to clamp_ is a view (e.g. because it comes straight out of a convolution) and that it is not used anywhere else, you can replace it with clamp and proceed. This does require having an opinion on which tensors might share memory (the alias analysis in the PyTorch JIT does that).

The last discussion we had on this was #6049.

@masahi (Member) commented Sep 15, 2020

What you would need to do is preprocess the graph to remove these side effects: for example, if you can rule out that the input to clamp_ is a view (e.g. because it comes straight out of a convolution) and that it is not used anywhere else, you can replace it with clamp and proceed. This does require having an opinion on which tensors might share memory (the alias analysis in the PyTorch JIT does that).

It seems PyTorch already has a set of passes to remove mutation. For example, this pass looks like it removes inplace ops: https://github.com/pytorch/pytorch/blob/61623430d32c081124067fc4abf33bd4a023af90/torch/csrc/jit/passes/remove_inplace_ops.cpp#L35

It would be interesting to run torch._C._jit_pass_remove_inplace_ops(...) and see what happens.
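
For reference, a minimal way to try that (an internal, undocumented pass, so availability and behavior depend on the PyTorch version):

```
import torch

@torch.jit.script
def update(x, y):
    x[0] = y
    return x

g = update.graph
print(g)                                  # contains aten::select + aten::copy_ for the slice assignment
torch._C._jit_pass_remove_inplace_ops(g)  # the internal JIT pass suggested above
print(g)                                  # per the observation below, the copy_ is left untouched
```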

@yongwww (Member, Author) commented Sep 15, 2020

It would be interesting to run torch._C._jit_pass_remove_inplace_ops(...) and see what happens.

The graph is still the same as before.

@t-vi (Contributor) commented Sep 15, 2020

It seems PyTorch already has a set of passes to remove mutation.

Note the fat warning at the top of the class, though.

A typical use of copy_ is something like this:

In [1]: import torch

In [2]: @torch.jit.script
   ...: def update(x, y):
   ...:     x[0] = y
   ...:     return x
   ...: 

In [3]: update.graph
Out[3]: 
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %7 : bool = prim::Constant[value=0]()
  %4 : int = prim::Constant[value=0]() # <ipython-input-2-e9bf1f9acfa9>:3:6
  %6 : Tensor = aten::select(%x.1, %4, %4) # <ipython-input-2-e9bf1f9acfa9>:3:4
  %8 : Tensor = aten::copy_(%6, %y.1, %7) # <ipython-input-2-e9bf1f9acfa9>:3:4
  return (%x.1)

@masahi (Member) commented Sep 15, 2020

The graph is still the same as before.

Yeah, this pass only supports a very small subset of inplace ops, see
https://github.com/pytorch/pytorch/blob/61623430d32c081124067fc4abf33bd4a023af90/torch/csrc/jit/passes/remove_inplace_ops.cpp#L7-L12

It does seem, though, that their ONNX export uses inplace op removal (via a different pass than the one above) before emitting ONNX.

@yongwww yongwww force-pushed the ptod_maping branch 2 times, most recently from b474523 to df27ece on September 15, 2020 08:32
@yongwww (Member, Author) commented Sep 15, 2020

A typical use of copy_ is something like this: (see the example above)

I just tested with a test case similar to the one you provided; the slicing assignment doesn't take effect in the JIT graph.

```
# test case
def forward(self, *args):
    a = args[0]
    b = args[1]
    a[0] = b
    return a
```

related graph:

graph(%self : __torch__.Copy,
      %a : Float(2, 3),
      %b : Float(1, 3)):
  %3 : int = prim::Constant[value=0]() 
  %4 : int = prim::Constant[value=0]() 
  %5 : Float(3) = aten::select(%a, %3, %4) # 
  %6 : int = prim::Constant[value=3]() 
  %7 : int[] = prim::ListConstruct(%6)
  %8 : Float(3) = aten::view(%b, %7) 
  %9 : bool = prim::Constant[value=0]()
  %10 : Float(3) = aten::copy_(%5, %8, %9) 
  return (%a)

@t-vi (Contributor) commented Sep 15, 2020

@yongwww what does "the slicing assignment doesn't take effect in the JIT graph" mean? The JITed function does update a, because select returns a view that uses the same memory as a, and copy_ then writes to that. I tried with an a that is all zeros and a b that contains random values.

Which thing in maskrcnn-benchmark produces the copy_, and why is your translation correct? Maybe we can eliminate the copy_ beforehand.
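
A quick way to check this (a scripted sketch of the test case above with matching shapes, not the exact code either of us ran):

```
import torch

class Copy(torch.nn.Module):
    def forward(self, a, b):
        a[0] = b
        return a

m = torch.jit.script(Copy())
a = torch.zeros(2, 3)
b = torch.rand(3)
out = m(a, b)

assert out.data_ptr() == a.data_ptr()  # the returned tensor shares memory with the input
assert torch.equal(a[0], b)            # the in-place write is visible in the original `a`
```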

@yongwww (Member, Author) commented Sep 15, 2020

@t-vi I see, thanks for pointing out the select. For your case the JIT graph is okay and works fine with PyTorch, but it seems TVM doesn't support this case without a real in-place operation. For reference, the slicing is here: https://github.com/facebookresearch/maskrcnn-benchmark/blob/4ac0a3b440320ff0d0159a280f2897d1f9f996f3/maskrcnn_benchmark/modeling/box_coder.py#L85-L93

@yongwww (Member, Author) commented Sep 15, 2020

@masahi @t-vi @zhiics @kevinthesun
To sum up the discussion above: the mapping in this PR supports some test cases (added in test_forward) which are not real in-place copies.

For the test case @t-vi provided, a real in-place copy is required, and the implementation in this PR doesn't work for that.

I would suggest adding a warning and a TODO to this mapping, since other already-merged in-place op mappings should have similar issues; I don't have a quick solution in mind at this point. Any suggestions?

@masahi (Member) commented Sep 15, 2020

@yongwww to be clear, does this translation work for the box_coder from maskrcnn-benchmark you mentioned above? https://github.com/facebookresearch/maskrcnn-benchmark/blob/4ac0a3b440320ff0d0159a280f2897d1f9f996f3/maskrcnn_benchmark/modeling/box_coder.py#L85-L93 That looks similar to the test case @t-vi provided.

By "work", of course, I mean that we get the same output as PyTorch would.

I would suggest adding a warning and a TODO to this mapping, since other already-merged in-place op mappings should have similar issues; I don't have a quick solution in mind at this point. Any suggestions?

I'm okay with this. If we can somehow detect inplace op patterns that we cannot support, we should immediately abort rather than give a warning and emit a wrong Relay graph.
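
Something along these lines in the converter loop would do it (a hypothetical sketch of the idea, not TVM's actual frontend code):

```
# Ops whose conversion would silently change semantics without true in-place support.
UNSUPPORTED_INPLACE_OPS = {"aten::copy_"}

def check_op_supported(node_kind):
    # Fail fast instead of emitting a Relay graph that produces wrong results.
    if node_kind in UNSUPPORTED_INPLACE_OPS:
        raise NotImplementedError(
            "{} requires true in-place semantics and cannot be converted".format(node_kind)
        )
```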

@yongwww (Member, Author) commented Sep 16, 2020

@masahi the copy_ mapping in this PR doesn't work for the code snippet in box_coder that I pointed to; similar to the case @t-vi provided, the slicing assignment doesn't take effect. After modifying the model definition script, we got copy_ removed. clamp_ in this PR is also an inplace op, and it works well for maskrcnn-benchmark.
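
As an illustration of the kind of rewrite involved (a simplified, hypothetical pattern, not the actual box_coder code), a slice assignment can usually be replaced by a functional construction so that no copy_ appears in the graph:

```
import torch

# Original style: slice assignment, which roughly scripts to aten::slice/select + aten::copy_
def decode_inplace(dx, dy):
    boxes = torch.zeros(dx.shape[0], 2)
    boxes[:, 0] = dx
    boxes[:, 1] = dy
    return boxes

# Rewritten style: build the result functionally, so no inplace op is generated
def decode_functional(dx, dy):
    return torch.stack([dx, dy], dim=1)
```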

@yongwww (Member, Author) commented Sep 16, 2020

@masahi I just checked the op list under torch.Tensor; there are a bunch of inplace ops there. It seems an op with the suffix _ is an inplace one: https://discuss.pytorch.org/t/how-to-identify-the-in-place-operation-in-pytorch/18592
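
For example, a rough way to list them from a scripted graph (a pure name-suffix heuristic; it would also match dunder ops such as aten::__and__, so it is no substitute for real alias analysis):

```
import torch

@torch.jit.script
def f(x):
    x.clamp_(min=0.0)
    return x

inplace_kinds = [
    n.kind()
    for n in f.graph.nodes()
    if n.kind().startswith("aten::") and n.kind().endswith("_")
]
print(inplace_kinds)  # e.g. ['aten::clamp_']
```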

@masahi (Member) commented Sep 16, 2020

@masahi I just checked the op list under torch.Tensor; there are a bunch of inplace ops there. It seems an op with the suffix _ is an inplace one: https://discuss.pytorch.org/t/how-to-identify-the-in-place-operation-in-pytorch/18592

Yes, that is also my understanding. Removing and not supporting copy_ sounds good, as it is commonly used for a pattern we cannot support, but for other inplace ops like clamp_ and the others already in the frontend, I think we can keep them. What do you think, @t-vi?

@yongwww yongwww closed this Sep 16, 2020
@yongwww yongwww reopened this Sep 16, 2020
@t-vi (Contributor) commented Sep 16, 2020

Yeah, any op ending in _ will be inplace, and there are many.

Without keeping track of memory locations or some construct to represent copy_, we likely cannot faithfully represent all possible uses of inplace ops.

So without a purely functional way to do this, we have to balance two things:

  • With inplace ops, we will take in idioms we cannot support if we do so without further analysis (because the example would also apply to clamp_, as seen below).
  • Without things like relu_, we might exclude many networks that would otherwise run fine (an inplace relu_ after something whose output isn't needed for the backward pass is quite popular).

This also appears to be the rationale for the ad-hoc inplace removal in the ONNX export, and I'm sure the PyTorch-ONNX people had a good think about whether there are better options (they have more alias analysis available to them). So I guess whatever they let through by just assuming inplace can be made out-of-place might also be reasonable to treat that way here. Maybe the list they use is a good guide.

I would imagine that someone somewhere does things like

def pn_relu(x):
   x[0].clamp_(min=0)
   x[1].clamp_(max=0)

(so the nonzero bits would be normally distributed for normally distributed inputs to be clever about exploding/vanishing activations), but I don't think it's terribly popular, so I would not worry about it.

For maskrcnn-benchmark in particular, I did a bit of tracing almost two years ago when I experimented with my PyTorch POC port to Android (though the tracing PR was never merged), and it should be easy to remove their inplace use. But it seems more useful to focus on the TorchVision implementations: maskrcnn-benchmark is really dormant, and my impression is that it itself wasn't fixed much; instead the lessons learned there have flowed into TorchVision, where people have worked on the JIT/ONNX/quantization properties.

@yongwww (Member, Author) commented Sep 17, 2020

@masahi @t-vi it seems the translation for every inplace op has a similar issue to copy_. Considering that the translation works for some cases (those without a real in-place update), I would like to keep copy_ and clamp_ in this PR and add a warning for all inplace ops. What do you think? Getting more models supported is the first step for us; if the overall inference result is not correct, one of the first things to look at would be the inplace op mapping warnings. Does this make sense to you?

Or do I need to remove copy_ and clamp_? We did encounter some models with clamp_ and other inplace ops from customers, and we didn't see accuracy issues frequently, so having the inplace op translation offered value for us. In the long term, I'm looking forward to a general and well-performing solution for inplace support in TVM.

@masahi (Member) commented Sep 17, 2020

Is there a valid usage of copy_ that we can (and you want to) support? From the discussion so far, I was under the impression that copy_ is always generated from an idiom that requires true inplace support.

@yongwww (Member, Author) commented Sep 17, 2020

Is there a valid usage of copy_ that we can (and you want to) support? From the discussion so far, I was under the impression that copy_ is always generated from an idiom that requires true inplace support.

For example, the copy_ translation works for the test cases I provided in test_forward.py. But the case from @t-vi is more common, and for that the translation doesn't work.

@masahi (Member) commented Sep 17, 2020

I mean, what kind of real-world PyTorch code ends up generating copy_? Nobody writes torch.copy_() explicitly, right? If the usage of copy_ is common and it is safe to ignore the inplace semantics, then having support for it makes sense.

@yongwww (Member, Author) commented Sep 17, 2020

I mean, what kind of real-world PyTorch code ends up generating copy_? Nobody writes torch.copy_() explicitly, right?

I wouldn't write a model using torch.copy_() either. It seems copy_ shows up in real-world applications in cases similar to the one @t-vi provided. Does keeping clamp_ and removing copy_ sound good to you? I am okay with removing either of them.

@masahi (Member) commented Sep 17, 2020

Yes, I think we can keep clamp_, but copy_ seems to always require true inplace support, so we shouldn't support it (even if we add a conversion for copy_, the output from TVM would be wrong, so it is better not to support it).
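
For clamp_, keeping it amounts to treating it as the functional clamp. A hypothetical converter sketch (not necessarily the PR's actual code, and only safe when the input is not a view and is not reused elsewhere):

```
from tvm import relay

def _clamp_inplace_as_functional(inputs, input_types):
    # inputs assumed to be [data, min, max] for aten::clamp_ / aten::clamp;
    # assumes both bounds are given as numbers (a real converter must handle None).
    data, a_min, a_max = inputs
    return relay.clip(data, float(a_min), float(a_max))
```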

@yongwww (Member, Author) commented Sep 17, 2020

Yes, I think we can keep clamp_, but copy_ seems to always require true inplace support, so we shouldn't support it (even if we add a conversion for copy_, the output from TVM would be wrong, so it is better not to support it).

Makes sense to me. Thanks, I will do that.

@masahi (Member) commented Sep 17, 2020

I also want to take some time to study how the PyTorch ONNX export handles inplace ops, and whether we can reuse some of their passes.

@masahi masahi merged commit 28ea54a into apache:master Sep 18, 2020
@masahi (Member) commented Sep 18, 2020

Thanks @yongwww @t-vi

kevinthesun pushed a commit to kevinthesun/tvm that referenced this pull request Sep 18, 2020
* Add copy_ and clamp_ in PyTorch frontend

* add true_divide in PyTorch frontend

* more test cases for copy_

* fix format

* remove copy_

* fix format

* skip true_divide for torch < 1.5
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Sep 18, 2020
* Add copy_ and clamp_ in PyTorch frontend

* add true_divide in PyTorch frontend

* more test cases for copy_

* fix format

* remove copy_

* fix format

* skip true_divide for torch < 1.5
@yongwww yongwww deleted the ptod_maping branch December 8, 2020 12:55