[Unity][Frontend] FX exp and strided_slice fix (#14338)
* Add support of `exp` for the FX translator.
* Fix the FX translator's handling of `None` in a torch tensor slice
(e.g., `x[:, None, None]`). A `None` index means dimension expansion,
but the previous implementation mistakenly advanced the input-dimension
counter when it saw `None`, which eventually pushed the counter out of
range. The fix records each expanded dimension's position in the output
shape instead of consuming an input dimension.
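For reference, the PyTorch indexing semantics being fixed, as a minimal
eager-mode illustration (plain PyTorch, independent of TVM):

    import torch

    x = torch.randn(8, 16)
    # `None` in a subscript inserts a new size-1 dimension (like
    # `unsqueeze`); it consumes no input dimension, so a walker over the
    # input dims must not advance past it.
    y = x[:, None, None, :, None]
    assert y.shape == (8, 1, 1, 16, 1)
    assert torch.equal(y, x.unsqueeze(1).unsqueeze(2).unsqueeze(-1))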
MasterJH5574 authored Mar 20, 2023
1 parent 57b42a8 commit 2a9709c
Showing 2 changed files with 49 additions and 6 deletions.
7 changes: 5 additions & 2 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -136,6 +136,9 @@ def _call_binary_op(self, op, lhs, rhs):
     def _cos(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.cos(self.env[node.args[0]]))
 
+    def _exp(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.exp(self.env[node.args[0]]))
+
     def _sin(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))

@@ -858,8 +861,7 @@ def _getitem(self, node: fx.node.Node) -> relax.Var:
                 axes.append(i)
                 i = i + 1
             elif index is None:
-                expand_dim.append(i)
-                i = i + 1
+                expand_dim.append(len(axes) + len(expand_dim))
             else:
                 raise ValueError("Unsupported index type: " + str(type(index)))
         while i < len(shape):
@@ -903,6 +905,7 @@ def create_convert_map(self):
             nn.modules.sparse.Embedding: self._embedding,
             # call_function and call_method
             "cos": self._cos,
+            "exp": self._exp,
             "sin": self._sin,
             "add": self._add,
             "floordiv": self._floordiv,
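To see the off-by-one concretely, here is a minimal standalone sketch of
the `_getitem` bookkeeping above (illustrative only: `plan_getitem` is
not a real TVM function, and only the `slice`/`None` branches of the
handler are mirrored):

    # `i` walks the *input* dimensions; `None` consumes no input dimension.
    def plan_getitem(shape, indices):
        axes, expand_dim, i = [], [], 0
        for index in indices:
            if isinstance(index, slice):
                axes.append(i)  # this sliced axis exists in the input
                i = i + 1
            elif index is None:
                # Fixed behavior: the new axis position is counted in
                # *output* dimensions. The old `expand_dim.append(i);
                # i = i + 1` consumed an input dim that was never there.
                expand_dim.append(len(axes) + len(expand_dim))
            else:
                raise ValueError("Unsupported index type: " + str(type(index)))
        while i < len(shape):  # remaining input dims are kept whole
            axes.append(i)
            i = i + 1
        return axes, expand_dim

    # x[:, None, None, :, None] on shape (8, 16):
    assert plan_getitem((8, 16), [slice(None), None, None, slice(None), None]) == (
        [0, 1],
        [1, 2, 4],
    )

With the old code, the same walk would leave `i` at 3 after the two
`None`s, so the second `slice` would record input axis 3, which does not
exist on a rank-2 tensor; hence the out-of-range failure described in
the commit message.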
48 changes: 44 additions & 4 deletions tests/python/relax/test_frontend_from_fx.py
@@ -19,7 +19,7 @@
 import tvm
 from tvm import relax
 import tvm.testing
-from tvm.script.parser import relax as R, tir as T
+from tvm.script.parser import ir as I, relax as R, tir as T
 
 
 def verify_model(torch_model, input_info, binding, expected):
@@ -1372,8 +1372,6 @@ def test_getitem():
     torch.set_grad_enabled(False)
     torch.random.manual_seed(0)
 
-    input_info = [([1, 3, 10, 10], "float32")]
-
     class Slice1(Module):
         def forward(self, x):
             return x[0, 1::2, :, :3]
@@ -1398,7 +1396,29 @@ def main(
                 R.output(gv)
             return gv
 
-    verify_model(Slice1(), input_info, {}, expected1)
+    class Slice2(Module):
+        def forward(self, x):
+            return x[:, None, None, :, None]
+
+    @I.ir_module
+    class expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((8, 16), dtype="float32")
+        ) -> R.Tensor((8, 1, 1, 16, 1), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice(
+                    inp_0, axes=[0, 1], begin=[0, 0], end=[8, 16], strides=[1, 1]
+                )
+                lv1: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.reshape(
+                    lv, R.shape([8, 1, 1, 16, 1])
+                )
+                gv: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(Slice1(), [([1, 3, 10, 10], "float32")], {}, expected1)
+    verify_model(Slice2(), [([8, 16], "float32")], {}, expected2)
 
 
 @tvm.testing.requires_gpu
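For context, the FX graph that `Slice2` produces can be inspected with
plain `torch.fx`, independent of TVM (a small sketch):

    import torch
    from torch import fx

    class Slice2(torch.nn.Module):
        def forward(self, x):
            return x[:, None, None, :, None]

    gm = fx.symbolic_trace(Slice2())
    for node in gm.graph.nodes:
        print(node.op, node.target, node.args)
    # The subscript traces to one call_function node targeting
    # operator.getitem, whose args mix slices and `None` entries; this is
    # the sequence `_getitem` now walks correctly, emitting the
    # strided_slice + reshape pair asserted in expected2.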
@@ -1451,6 +1471,26 @@ def main(

     verify_model(Cos(), input_info, {}, expected2)
 
+    # exp
+    class Exp(Module):
+        def forward(self, input):
+            return torch.exp(input)
+
+    @tvm.script.ir_module
+    class expected_exp:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Exp(), input_info, {}, expected_exp)
+
     # sqrt
     class Sqrt(Module):
         def forward(self, input):
