diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a52b73f3abd4..3d150e3eedbf 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -975,9 +975,15 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var: # functional.layer_norm if node.target not in self.named_modules: # static or symbolic - normalized_shape = ( - node.args[1] if type(node.args[1]) == tuple else self.env[node.args[1]] - ) + arg = node.args[1] + if isinstance(arg, tuple): + value = arg + else: + try: + value = self.env[arg] + except TypeError: + value = tuple(arg) + normalized_shape = value dim_num = len(normalized_shape) axes = list(range(-dim_num, 0)) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index c81505107681..a1acff4974b1 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1152,6 +1152,46 @@ def main( binding = {} verify_model(model, input_info, binding, expected2) + class LayerNorm3(Module): + def __init__(self, shape): + super().__init__() + self.shape = shape + self.weight = torch.nn.Parameter(torch.ones(shape)) + self.bias = torch.nn.Parameter(torch.zeros(shape)) + + def forward(self, input): + return torch.nn.functional.layer_norm(input, self.shape, self.weight, self.bias, 1e-5) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor([10, 10], dtype="float32"), + w2: R.Tensor([10, 10], dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm( + input_1, + w1, + w2, + axes=[-2, -1], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + model = LayerNorm3([10, 10]) + binding = { + "w1": model.weight.detach().numpy(), + "w2": model.bias.detach().numpy(), + } + verify_model(model, input_info, binding, expected3) + def test_cross_entropy(): input_info = [([3, 2], "float32"), ([3], "int32")]