From 2a9709c90beaf816607402b91b3e016b553375b3 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 20 Mar 2023 10:05:44 -0400 Subject: [PATCH] [Unity][Frontend] FX exp and strided_slice fix (#14338) * Add the support of `exp` for the FX translator. * Previously the way FX translator dealt with `None` in torch tensor slice (e.g., `x[:, None, None]`) is not right. This PR fixes this issue. Specifically, the `None` here means dim expansion, and the previous impl mistakenly increases the dim counter when seeing `None`, which will lead to dim counter out-of-range issue in the end. --- .../tvm/relax/frontend/torch/fx_translator.py | 7 ++- tests/python/relax/test_frontend_from_fx.py | 48 +++++++++++++++++-- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 0bd987cf2dcd..a2e2afe668da 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -136,6 +136,9 @@ def _call_binary_op(self, op, lhs, rhs): def _cos(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.cos(self.env[node.args[0]])) + def _exp(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.exp(self.env[node.args[0]])) + def _sin(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.sin(self.env[node.args[0]])) @@ -858,8 +861,7 @@ def _getitem(self, node: fx.node.Node) -> relax.Var: axes.append(i) i = i + 1 elif index is None: - expand_dim.append(i) - i = i + 1 + expand_dim.append(len(axes) + len(expand_dim)) else: raise ValueError("Unsupported index type: " + str(type(index))) while i < len(shape): @@ -903,6 +905,7 @@ def create_convert_map(self): nn.modules.sparse.Embedding: self._embedding, # call_function and call_method "cos": self._cos, + "exp": self._exp, "sin": self._sin, "add": self._add, "floordiv": self._floordiv, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 31b43070cb64..2e69795d51ee 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -19,7 +19,7 @@ import tvm from tvm import relax import tvm.testing -from tvm.script.parser import relax as R, tir as T +from tvm.script.parser import ir as I, relax as R, tir as T def verify_model(torch_model, input_info, binding, expected): @@ -1372,8 +1372,6 @@ def test_getitem(): torch.set_grad_enabled(False) torch.random.manual_seed(0) - input_info = [([1, 3, 10, 10], "float32")] - class Slice1(Module): def forward(self, x): return x[0, 1::2, :, :3] @@ -1398,7 +1396,29 @@ def main( R.output(gv) return gv - verify_model(Slice1(), input_info, {}, expected1) + class Slice2(Module): + def forward(self, x): + return x[:, None, None, :, None] + + @I.ir_module + class expected2: + @R.function + def main( + inp_0: R.Tensor((8, 16), dtype="float32") + ) -> R.Tensor((8, 1, 1, 16, 1), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice( + inp_0, axes=[0, 1], begin=[0, 0], end=[8, 16], strides=[1, 1] + ) + lv1: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.reshape( + lv, R.shape([8, 1, 1, 16, 1]) + ) + gv: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Slice1(), [([1, 3, 10, 10], "float32")], {}, expected1) + verify_model(Slice2(), [([8, 16], "float32")], {}, expected2) @tvm.testing.requires_gpu @@ -1451,6 +1471,26 @@ def main( verify_model(Cos(), input_info, {}, expected2) + # exp + class Exp(Module): + def forward(self, input): + return torch.exp(input) + + @tvm.script.ir_module + class expected_exp: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Exp(), input_info, {}, expected_exp) + # sqrt class Sqrt(Module): def forward(self, input):