Commit

Update vit.jl
st.attention -> st.attention_dropout in vit.jl
aksuhton authored Sep 21, 2023
1 parent 1b0334e commit 88bdb80
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/vision/vit.jl
@@ -40,7 +40,7 @@ function (m::MultiHeadAttention)(x::AbstractArray{T, 3}, ps, st) where {T}
                            seq_len * batch_size)

     attention = softmax(batched_mul(query_reshaped, key_reshaped) .* scale)
-    attention, st_attention = m.attention_dropout(attention, ps.attention_dropout,
+    attention, st_attention_dropout = m.attention_dropout(attention, ps.attention_dropout,
                                                   st.attention_dropout)

     value_reshaped = reshape(value, nfeatures ÷ m.number_heads, m.number_heads,
@@ -50,7 +50,7 @@ function (m::MultiHeadAttention)(x::AbstractArray{T, 3}, ps, st) where {T}
     y, st_projection = m.projection(reshape(pre_projection, size(pre_projection, 1), :),
                                     ps.projection, st.projection)

-    st_ = (qkv_layer=st_qkv, attention=st_attention, projection=st_projection)
+    st_ = (qkv_layer=st_qkv, attention_dropout=st_attention_dropout, projection=st_projection)
     return reshape(y, :, seq_len, batch_size), st_
 end
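
The renamed key is what makes the returned state reusable: the layer reads its dropout state back in as st.attention_dropout (first hunk), so returning it under the key attention broke state threading on the next forward pass. Below is a minimal sketch of that round trip in Lux.jl's functional style; the MultiHeadAttention constructor arguments here are illustrative assumptions, not taken from this file.

using Lux, Random

# Hypothetical constructor call (embedding dim 64, 8 heads); the real
# signature lives elsewhere in Boltz.jl and may differ.
mha = MultiHeadAttention(64, 8)

rng = Random.default_rng()
ps, st = Lux.setup(rng, mha)

x = randn(Float32, 64, 16, 4)  # (features, seq_len, batch_size): an AbstractArray{T, 3}

y1, st1 = mha(x, ps, st)
# Before this commit, st1 stored the dropout state under :attention while the
# layer reads st.attention_dropout, so feeding st1 back in failed. With the
# fix, state threads cleanly across calls:
y2, st2 = mha(x, ps, st1)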
