Commit 0e77cf6
Fixed linter issues
suryasidd committed Feb 9, 2024
1 parent 903d40f commit 0e77cf6
Showing 1 changed file with 9 additions and 9 deletions.
@@ -1,6 +1,7 @@
import torch
from torch._decomp.decompositions import aten, pw_cast_for_opmath
-from torch._decomp import decomposition_table, register_decomposition, get_decompositions
+from torch._decomp import register_decomposition


@register_decomposition(aten.convolution_backward)
@pw_cast_for_opmath
@@ -27,7 +28,7 @@ def convolution_backward(

    # Compute the gradient of the weight tensor
    grad_weight = torch.nn.functional.conv_transpose2d(
-        input, weight.transpose(0,1), stride=stride, padding=padding, dilation=dilation, groups=groups, output_padding=output_padding
+        input, weight.transpose(0, 1), stride=stride, padding=padding, dilation=dilation, groups=groups, output_padding=output_padding
    )

    # Compute the gradient of the bias tensor
@@ -39,19 +40,17 @@ def convolution_backward(
    return grad_input, grad_weight, grad_bias



@register_decomposition(aten._scaled_dot_product_flash_attention.default)
def scaled_dot_product_flash_attention(
    query,
    key,
    value,
-    dropout_p = 0.0,
-    is_causal = False,
+    dropout_p=0.0,
+    is_causal=False,
    *,
-    return_debug_mask = False,
-    scale = None,
+    return_debug_mask=False,
+    scale=None,
):
    dtype = query.dtype
    batchSize, num_head, qSize, headSize = (
        query.shape[0],
        query.shape[1],
@@ -77,7 +76,7 @@ def scaled_dot_product_flash_attention(
        query, key, value, None, dropout_p, is_causal, None, scale=scale
    )

-    scores = torch.matmul(query, key.transpose(-2,-1))/ (key.size(-1) ** 0.5)
+    scores = torch.matmul(query, key.transpose(-2, -1)) / (key.size(-1) ** 0.5)
    logsumexp = torch.logsumexp(scores, dim=-1)

    output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
@@ -93,6 +92,7 @@ def scaled_dot_product_flash_attention(
        debug_attn_mask,
    )


def get_aot_decomposition_list():
    return ([torch.ops.aten._scaled_dot_product_flash_attention.default,
             torch.ops.aten._softmax.default,
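For context, a minimal usage sketch (not part of this commit) of how ops like those listed by get_aot_decomposition_list() can be turned into a decomposition table via torch._decomp.get_decompositions; the op list below is only the visible portion of the truncated return value, and any downstream compiler wiring is assumed rather than taken from this file.

import torch
from torch._decomp import get_decompositions

# Illustrative op list; mirrors the ops visible in the diff above.
aot_ops = [
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._softmax.default,
]

# Build an op -> decomposition mapping. Overrides registered earlier with
# @register_decomposition are looked up from the global decomposition table.
decomp_table = get_decompositions(aot_ops)
print(sorted(str(op) for op in decomp_table))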
