diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 37298f3a9..62edd7caa 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1769,8 +1769,9 @@ def aten_scaled_dot_product_attention(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    enable_gqa: bool = False,
 ) -> TFloat:
-    """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
+    """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor

     Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

@@ -1790,6 +1791,10 @@ def aten_scaled_dot_product_attention(
         is_causal and attn_mask is None
     ), "is_causal and attn_mask cannot be set at the same time"

+    assert (
+        not enable_gqa
+    ), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+
     # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     if scale is None:
         scale = _attention_scale(query)
@@ -1982,8 +1987,9 @@ def aten_scaled_dot_product_attention_bool_mask(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    enable_gqa: bool = False,
 ) -> TFloat:
-    """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
+    """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor

     Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

@@ -2003,6 +2009,10 @@ def aten_scaled_dot_product_attention_bool_mask(
         is_causal and attn_mask is None
     ), "is_causal and attn_mask cannot be set at the same time"

+    assert (
+        not enable_gqa
+    ), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+
     if scale is None:
         scale = _attention_scale(query)
         scale = op.CastLike(scale, query)
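
Note on `enable_gqa` (context for the new asserts, not part of the diff): recent PyTorch releases (2.5+) added an `enable_gqa` keyword to `torch.nn.functional.scaled_dot_product_attention` for grouped-query attention, where the query tensor has more heads than the key/value tensors and each key/value head is shared by a group of query heads. This change only registers the parameter so the function signatures match the updated ATen schema; conversion with `enable_gqa=True` is explicitly rejected. The snippet below is a minimal sketch in plain PyTorch (assuming torch >= 2.5 and head counts chosen only for illustration) of how the GQA case reduces to the already-supported non-GQA attention by repeating key/value heads:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 8 query heads sharing 2 key/value heads (GQA).
batch, q_heads, kv_heads, seq, head_dim = 1, 8, 2, 16, 64
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# With enable_gqa=True, PyTorch shares each K/V head across a group of query heads.
gqa_out = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)

# One possible lowering (a sketch, not what this PR implements): repeat each
# K/V head q_heads // kv_heads times so the head counts match, then call the
# ordinary non-GQA attention.
group_size = q_heads // kv_heads
key_expanded = key.repeat_interleave(group_size, dim=1)
value_expanded = value.repeat_interleave(group_size, dim=1)
mha_out = F.scaled_dot_product_attention(query, key_expanded, value_expanded)

# The two forms agree up to numerical tolerance.
torch.testing.assert_close(gqa_out, mha_out)
```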