add scoring methods in Luong-style attention #15867
Conversation
Add scoring methods, particularly `concat`, for Luong-style attention.
Added the tests `test_shape_with_key_concat`, `test_shape_concat`, `test_calculate_scores_one_dim_with_scale_concat`, and `test_calculate_scores_multi_dim_concat`.
Thanks for the PR!
keras/layers/dense_attention.py
    @@ -242,6 +242,8 @@ class Attention(BaseDenseAttention):
          Defaults to `False`.
        dropout: Float between 0 and 1. Fraction of the units to drop for the
          attention scores. Defaults to 0.0.
        score: One of {'dot', 'concat'}. 'dot' refers to dot multiplication
          of query and key. 'concat' refers to hyperbolic tangent of query and key.
It's `tanh` of the sum. Why is this called `concat` in this case?
> It's `tanh` of the sum.

It's not exactly tanh of the sum. Key and query are concatenated, much like the score in the AdditiveAttention layer (Bahdanau-style attention).

> Why is this called `concat` in this case?

In the paper they call this score concat.
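For reference, the concat score as defined in Luong et al. (2015), "Effective Approaches to Attention-based Neural Machine Translation":

$$\mathrm{score}(h_t, \bar{h}_s) = v_a^\top \tanh\left(W_a\,[h_t; \bar{h}_s]\right)$$

where $[h_t; \bar{h}_s]$ is the concatenation of the target (query) and source (key) hidden states.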
Sounds good then, but please improve the argument description. Currently the description seems rather confusing.
Is this okay?
"score_mode: One of {'dot', 'concat'}. 'dot' refers to dot product of query and key. 'concat' refers to hyperbolic tangent of query and key concatenated."
Do we need to describe the parameter `score_type` too?
Use:
score_mode: Function to use to compute attention scores, one of `{"dot", "concat"}`.
`"dot"` refers to the dot product between the query and key vectors.
`"concat"` refers to the hyperbolic tangent of the concatenation of the query and key vectors.
keras/layers/dense_attention.py
    super(Attention, self).__init__(**kwargs)
    self.use_scale = use_scale
    self.score_type = score
    if self.score_type not in ['dot', 'concat']:
      logging.warning(f'Score type {self.score_type} is unknown, '
This should be a `ValueError`.
Made the requested changes.
Thanks for the update!
    if self.score_mode not in ['dot', 'concat']:
      raise ValueError(
          "Unknown score_mode. Acceptable values "
          "are: ['dot', 'concat']"
Add: `f"Received: score_mode={score_mode}"`. Also make sure you use `'` and `"` consistently.
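Putting the review comments together, a minimal sketch of the requested constructor validation (subclassing `tf.keras.layers.Layer` here to keep the example self-contained; the actual PR subclasses the internal `BaseDenseAttention`):

```python
import tensorflow as tf

class Attention(tf.keras.layers.Layer):
    """Sketch of the constructor with the requested score_mode check."""

    def __init__(self, use_scale=False, score_mode="dot", **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.use_scale = use_scale
        self.score_mode = score_mode
        if self.score_mode not in ["dot", "concat"]:
            # A ValueError rather than a warning, with the received value
            # echoed back, and double quotes used consistently.
            raise ValueError(
                "Unknown score_mode. Acceptable values are: "
                f"['dot', 'concat']. Received: score_mode={score_mode}"
            )
```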
@fchollet can you please take a look?
Thanks for the contribution!
keras/layers/dense_attention_test.py
    dtype=np.float32)
    attention_layer = dense_attention.Attention(score_mode='concat')
    attention_layer.build(input_shape=([1, 2, 4], [1, 3, 4]))
    actual = attention_layer._calculate_scores(query=q, key=k)
Test is failing: https://source.cloud.google.com/results/invocations/acd8077d-49cf-4ad9-a164-2c37718c3a47/targets/keras%2Fgithub%2Fubuntu%2Fcpu%2Fpresubmit/log
You may need to do `actual = keras.backend.get_value(actual)`.
Is it because I am not initializing the parameter `attention_v`, only creating it when `score_mode='concat'`? Because `score_mode='dot'` is passing, and the `scale` parameter is also initialized to `None` at the start.
So what do you get when you replace it with `actual = keras.backend.get_value(actual)`?
To maintain uniformity with the other tests I haven't used this yet. The error message is:

    Node: 'ReadVariableOp'
    Could not find variable attention_v. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Container localhost does not exist. (Could not find resource: localhost/attention_v)
    [[{{node ReadVariableOp}}]]

So I initialized `attention_v` (now `concat_score_weight`) to `None`, just like `scale`.
The error is unrelated to this. It will likely go away if you use `get_value`.
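A hedged sketch of the failing test pattern with the suggested fix applied (the array values and shapes are invented for illustration; `dense_attention` is the module path used elsewhere in this thread):

```python
import numpy as np
from tensorflow import keras
from keras.layers import dense_attention

q = np.array([[[1.1], [2.1]]], dtype=np.float32)         # [1, 2, 1] query
k = np.array([[[1.6], [0.7], [0.3]]], dtype=np.float32)  # [1, 3, 1] key

attention_layer = dense_attention.Attention(score_mode='concat')
attention_layer.build(input_shape=([1, 2, 1], [1, 3, 1]))
scores = attention_layer._calculate_scores(query=q, key=k)
# Materialize the tensor before comparing it with numpy expectations,
# as suggested above.
actual = keras.backend.get_value(scores)
```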
Made the requested changes.
keras/layers/dense_attention_test.py
    attention_layer = dense_attention.Attention(score_mode='concat')
    attention_layer.build(input_shape=([1, 2, 4], [1, 3, 4]))
    actual = attention_layer._calculate_scores(query=q, key=k)
    attention_layer.attention_v = 1
I notice that it isn't clear what `attention_v` means. Can you rename it to a fully spelled out variable name?
Ah, so in `concat`, the score is the product of a learnable parameter (here `attention_v`) and the hyperbolic tangent of the concatenation of the query and key vectors. The v in the equation above is what I am calling `attention_v`; as usual, Wa is the scaling parameter, and ht and hs are the query and key respectively. A PyTorch implementation of the same can be found here, in code block 11.
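To make that concrete, here is a minimal sketch of a concat scoring function. It uses the same broadcast-sum trick as Keras's `AdditiveAttention`, which is equivalent to applying a weight matrix to the concatenation; the standalone function form and its name are mine, not the PR's exact code:

```python
import tensorflow as tf

def concat_scores(query, key, concat_score_weight):
    """Luong 'concat' scores over all query/key pairs.

    query: float tensor of shape [batch, Tq, dim].
    key:   float tensor of shape [batch, Tv, dim].
    concat_score_weight: scalar learnable parameter (the v above).
    Returns scores of shape [batch, Tq, Tv].
    """
    q = tf.expand_dims(query, axis=-2)  # [batch, Tq, 1, dim]
    k = tf.expand_dims(key, axis=-3)    # [batch, 1, Tv, dim]
    # The broadcast sum plays the role of Wa applied to the
    # concatenation [ht; hs]; reduce_sum collapses the feature axis.
    return concat_score_weight * tf.reduce_sum(tf.tanh(q + k), axis=-1)
```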
Rename it to something that isn't an abbreviation, like `attention_weights`. See the Keras API style guide comments on naming: https://github.com/keras-team/governance/blob/master/keras_api_design_guidelines.md#carefully-weigh-whether-a-new-feature-should-be-included
Renamed `attention_v` to `concat_score_weight`. Does that work?
Thanks for the update -- let's try merging again.
Luong-style attention uses three types of scoring methods, namely dot, general, and concat. These can be found on the 3rd page of the original paper and are explained here. An implementation of the layer can be found in this notebook.
Auto closes #15866
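For context, a hedged usage sketch of the feature once merged, assuming `score_mode` is exposed on `tf.keras.layers.Attention` as discussed above:

```python
import tensorflow as tf

query = tf.random.normal((2, 5, 8))  # [batch, Tq, dim]
value = tf.random.normal((2, 6, 8))  # [batch, Tv, dim]

# Luong-style attention with the new concat scoring method.
attention = tf.keras.layers.Attention(score_mode="concat")
output = attention([query, value])
print(output.shape)  # (2, 5, 8)
```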