Skip to content

Commit

Permalink
Merge pull request #1120 from pytorch/fb-sync-khabinov
Browse files Browse the repository at this point in the history
[fx_acc] Add acc_tracer support for torch.mm
  • Loading branch information
Wei authored Jun 15, 2022
2 parents 058a511 + d897a0a commit c558de8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
Binary file removed docs/v1.1.0/._index.html
Binary file not shown.
7 changes: 7 additions & 0 deletions py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,13 @@ def square_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
("mat2", "other"),
],
)
@register_acc_op_mapping(
op_and_target=("call_function", torch.mm),
arg_replacement_tuples=[
("input", "input"),
("mat2", "other"),
],
)
@register_acc_op_mapping(op_and_target=("call_function", torch.matmul))
@register_acc_op
def matmul(*, input, other):
Expand Down

0 comments on commit c558de8

Please sign in to comment.