From 0e77cf6106e609ef7899dc99adafd7469656a243 Mon Sep 17 00:00:00 2001 From: suryasidd Date: Thu, 8 Feb 2024 17:46:24 -0800 Subject: [PATCH] Fixed linter issues --- .../pytorch/torchdynamo/decompositions.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py index 646cf75b916a87..bf42f0df29ebc5 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py @@ -1,6 +1,7 @@ import torch from torch._decomp.decompositions import aten, pw_cast_for_opmath -from torch._decomp import decomposition_table, register_decomposition, get_decompositions +from torch._decomp import register_decomposition + @register_decomposition(aten.convolution_backward) @pw_cast_for_opmath @@ -27,7 +28,7 @@ def convolution_backward( # Compute the gradient of the weight tensor grad_weight = torch.nn.functional.conv_transpose2d( - input, weight.transpose(0,1), stride=stride, padding=padding, dilation=dilation, groups=groups, output_padding=output_padding + input, weight.transpose(0, 1), stride=stride, padding=padding, dilation=dilation, groups=groups, output_padding=output_padding ) # Compute the gradient of the bias tensor @@ -39,19 +40,17 @@ def convolution_backward( return grad_input, grad_weight, grad_bias - @register_decomposition(aten._scaled_dot_product_flash_attention.default) def scaled_dot_product_flash_attention( query, key, value, - dropout_p = 0.0, - is_causal = False, + dropout_p=0.0, + is_causal=False, *, - return_debug_mask = False, - scale = None, + return_debug_mask=False, + scale=None, ): - dtype = query.dtype batchSize, num_head, qSize, headSize = ( query.shape[0], query.shape[1], @@ -77,7 +76,7 @@ def scaled_dot_product_flash_attention( query, key, value, None, dropout_p, is_causal, None, scale=scale ) - scores = torch.matmul(query, key.transpose(-2,-1))/ (key.size(-1) ** 0.5) + scores = torch.matmul(query, key.transpose(-2, -1)) / (key.size(-1) ** 0.5) logsumexp = torch.logsumexp(scores, dim=-1) output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format) @@ -93,6 +92,7 @@ def scaled_dot_product_flash_attention( debug_attn_mask, ) + def get_aot_decomposition_list(): return ([torch.ops.aten._scaled_dot_product_flash_attention.default, torch.ops.aten._softmax.default,