Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive committed Apr 26, 2024
1 parent 6b5228f commit de5499c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ def lower_scaled_dot_product_attention(
break

assert attention_node_replaced is not None
assert len(match.replacements) == 1

new_attention_node = match.replacements[0]

assert (
new_attention_node.target
== torch.nn.functional.scaled_dot_product_attention
)

# If the attention operator had keyword-args, copy them to the new node
if attention_node_replaced.kwargs:
assert len(match.replacements) == 1
new_attention_node = match.replacements[0]
assert (
new_attention_node.target
== torch.nn.functional.scaled_dot_product_attention
)
new_attention_node.kwargs = {**attention_node_replaced.kwargs}

# Set default args in new node:
Expand Down
1 change: 1 addition & 0 deletions tests/py/dynamo/conversion/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from ..testing_utilities import DECIMALS_OF_AGREEMENT
from .harness import DispatchTestCase


Expand Down

0 comments on commit de5499c

Please sign in to comment.