From 5b8dc2c7fbe8282a638e4d54ce09786255def691 Mon Sep 17 00:00:00 2001 From: Pedro Gonnet Date: Wed, 11 Jan 2023 06:09:43 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 501261604 --- .../models/multi_head_attention/layers.py | 59 +++++-- .../multi_head_attention/layers_test.py | 150 +++++++++++++++--- 2 files changed, 175 insertions(+), 34 deletions(-) diff --git a/tensorflow_gnn/models/multi_head_attention/layers.py b/tensorflow_gnn/models/multi_head_attention/layers.py index b58ec822..7ea8ac4f 100644 --- a/tensorflow_gnn/models/multi_head_attention/layers.py +++ b/tensorflow_gnn/models/multi_head_attention/layers.py @@ -1,3 +1,4 @@ +# pyformat: mode=yapf # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +14,7 @@ # limitations under the License. # ============================================================================== """Contains a Multi-Head Attention and associated layers.""" -from typing import Any, Callable, Collection, Mapping, Optional, Union +from typing import Any, Callable, Collection, Literal, Mapping, Optional, Union import warnings import tensorflow as tf @@ -72,13 +73,13 @@ class MultiHeadAttentionConv(tfgnn.keras.layers.AnyToAnyConvolutionBase): attended to each other, which means we do NOT compute $N^2$ pairs of scores as the original Transformer-style Attention. - Users are able to remove the scaling of attention scores (score_scaling=False) - or add an activation on the transformed query (controled by - `attention_activation`). However, we recommend to remove the scaling when - using an `attention_activation` since activating both of them may lead to - degrated accuracy. One can also customize the transformation kernels with - different intializers, regularizers as well as the use of bias terms, using - the other arguments. + Users are able to remove the scaling of attention scores + (`score_scaling="none"`) or add an activation on the transformed query + (controlled by `attention_activation`). However, we recommend to remove the + scaling when using an `attention_activation` since activating both of them may + lead to degraded accuracy. One can also customize the transformation kernels + with different initializers, regularizers as well as the use of bias terms, + using the other arguments. Example: Transformer-style attention on neighbors along incoming edges whose result is concatenated with the old node state and passed through @@ -157,9 +158,15 @@ class MultiHeadAttentionConv(tfgnn.keras.layers.AnyToAnyConvolutionBase): only queries are transformed since the two transformations on queries and keys are equivalent to one. (The presence of transformations on values is independent of this arg.) - score_scaling: If true, the attention scores are divided by the square root - of the dimension of keys (i.e., per_head_channels if transform_keys=True, - else whatever the dimension of combined sender inputs is). + score_scaling: One of either `"none"`, `"rsqrt_dim"`, or + `"trainable_sigmoid"`. If set to `"rsqrt_dim"`, the attention scores are + divided by the square root of the dimension of keys (i.e., + `per_head_channels` if `transform_keys=True`, otherwise whatever the + dimension of combined sender inputs is). 
If set to `"trainable_sigmoid"`, + the scores are scaled with `sigmoid(x)`, where `x` is a trainable weight + of the model that is initialized to `-5.0`, which initially makes all the + attention weights equal and slowly ramps up as the other weights in the + layer converge. Defaults to `"rsqrt_dim"`. transform_values_after_pooling: By default, each attention head applies the value transformation, then pools with attention coefficients. Setting this option pools inputs with attention coefficients, then applies @@ -186,7 +193,8 @@ def __init__( kernel_regularizer: Union[None, str, tf.keras.regularizers.Regularizer] = None, transform_keys: bool = True, - score_scaling: bool = True, + score_scaling: Literal["none", "rsqrt_dim", + "trainable_sigmoid"] = "rsqrt_dim", transform_values_after_pooling: bool = False, **kwargs): kwargs.setdefault("name", "multi_head_attention_conv") @@ -222,7 +230,7 @@ def __init__( self._edge_dropout_layer = None # Check for conflicting options. - if attention_activation is not None and score_scaling: + if attention_activation is not None and score_scaling != "none": warnings.warn( "using both an activation on transformed inputs and score scaling " "may lead to degraded accuracy if the activation function restricts " @@ -300,6 +308,9 @@ def __init__( kernel_regularizer=kernel_regularizer, name="value_pooled") + if self._score_scaling == "trainable_sigmoid": + self._score_scaling_weight = None + def get_config(self): return dict( num_heads=self._num_heads, @@ -419,9 +430,27 @@ def convolve(self, # [num_items, *extra_dims, num_heads, 1] attention_coefficients = tf.expand_dims( tf.einsum("...j,...j->...", queries, keys), axis=-1) - if self._score_scaling: + + # Optionally scale the attention scores. + if self._score_scaling == "none": + pass + elif self._score_scaling == "rsqrt_dim": attention_coefficients *= tf.math.rsqrt( - tf.cast(keys.shape[-1], tf.float32)) + tf.cast(tf.shape(keys)[-1], tf.float32)) + elif self._score_scaling == "trainable_sigmoid": + if self._score_scaling_weight is None: + self._score_scaling_weight = self.add_weight( + name="score_scaling", + shape=[], + dtype=tf.float32, + initializer=tf.keras.initializers.Constant(-5.0), + trainable=True, + ) + attention_coefficients *= tf.keras.activations.sigmoid( + self._score_scaling_weight) + else: + raise ValueError("Unknown value MultiHeadAttentionConv(" + f"score_scaling='{self._score_scaling}')") attention_coefficients = extra_receiver_ops["softmax"]( attention_coefficients) diff --git a/tensorflow_gnn/models/multi_head_attention/layers_test.py b/tensorflow_gnn/models/multi_head_attention/layers_test.py index 368d0716..f4286c4a 100644 --- a/tensorflow_gnn/models/multi_head_attention/layers_test.py +++ b/tensorflow_gnn/models/multi_head_attention/layers_test.py @@ -1,3 +1,4 @@ +# pyformat: mode=yapf # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,6 +16,7 @@ """Tests for Multi-Head Attention.""" import enum +import math import os from absl.testing import parameterized @@ -32,8 +34,7 @@ class ReloadModel(int, enum.Enum): class MultiHeadAttentionTest(tf.test.TestCase, parameterized.TestCase): - @parameterized.named_parameters(("", False), - ("TransformAfter", True)) + @parameterized.named_parameters(("", False), ("TransformAfter", True)) def testBasic(self, transform_values_after_pooling): """Tests that a single-headed MHA is correct given predefined weights.""" # NOTE: Many following tests use minor variations of the explicit @@ -121,13 +122,12 @@ def testBasic(self, transform_values_after_pooling): [0., 0., 0.]) else: # Same weights, but as Einsum kernel "hvc". - weights["multi_head_attention_conv/value_pooled/kernel:0"].assign( - [[ - [0., -1., 0.], - [-1., 0., 0.], - [-1., -1., 0.], - [0., 0., 1.1], - ]]) + weights["multi_head_attention_conv/value_pooled/kernel:0"].assign([[ + [0., -1., 0.], + [-1., 0., 0.], + [-1., -1., 0.], + [0., 0., 1.1], + ]]) weights["multi_head_attention_conv/value_pooled/bias:0"].assign( [[0., 0., 0.]]) @@ -165,8 +165,7 @@ def testBasic(self, transform_values_after_pooling): self.assertAllClose(got_2, want_2, atol=.0001) def testAttentionActivation(self): - """Tests that a single-headed MHA correctly applies attention activations. - """ + """Tests that a single-headed MHA correctly applies attention activations.""" # The same test graph as in the testBasic above. gt_input = _get_test_bidi_cycle_graph( @@ -177,8 +176,7 @@ def testAttentionActivation(self): ])) def get_conv(attention_activation=None): - """Constructs a MultiHeadAttentionConv with the given attention_activation. - """ + """Constructs a MultiHeadAttentionConv with the given attention_activation.""" conv = multi_head_attention.MultiHeadAttentionConv( num_heads=1, @@ -186,7 +184,7 @@ def get_conv(attention_activation=None): receiver_tag=tfgnn.TARGET, attention_activation=attention_activation, activation=None, - score_scaling=False, + score_scaling="none", ) _ = conv(gt_input, edge_set_name="edges") # Build weights. @@ -290,6 +288,122 @@ def get_conv(attention_activation=None): self.assertAllEqual(got.shape, (3, 3)) self.assertAllClose(got, want, atol=.0001) + def testScoreScalingTypes(self): + """Tests that the different types of score scaling are applied correctly.""" + + # The same test graph as in the testBasic above. + gt_input = _get_test_bidi_cycle_graph( + tf.constant([ + [1.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 2.0], + [0.0, 0.0, 1.0, 3.0], + ])) + + def get_conv(score_scaling=None): + """Constructs a MultiHeadAttentionConv with the given score_scaling.""" + + conv = multi_head_attention.MultiHeadAttentionConv( + num_heads=1, + per_head_channels=3, + receiver_tag=tfgnn.TARGET, + activation=None, + score_scaling=score_scaling, + ) + + _ = conv(gt_input, edge_set_name="edges") # Build weights. + weights = {v.name: v for v in conv.trainable_weights} + if score_scaling == "trainable_sigmoid": + # Additional trainable weight for the score scaling. + self.assertLen(weights, 7) + else: + self.assertLen(weights, 6) + + weights["multi_head_attention_conv/query/kernel:0"].assign( + # The node states times the query kernel should be: + # + # [[0., 1., 0.], + # [0., 0., -1.], + # [1., 0., 0.]] + # + # i.e. the second query vector has negative values, which, after + # activation with the `relu` function, should be all zeros. 
+ [ + [0.0, 1.0, 0.0], + [0.0, 0.0, -1.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ]) + weights["multi_head_attention_conv/query/bias:0"].assign([0.0, 0.0, 0.0]) + + weights["multi_head_attention_conv/key_node/kernel:0"].assign( + # The key_node kernel is chosen such that the the product with the + # node states is: + # + # [[-1., 0., 0.], + # [0., 1., 0.], + # [0., 0., 1.]] + # + # i.e. the third key vector has negative values, which, after + # activation with the `relu` function, should be all zeros. + [ + [-1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 0.0], + ]) + weights["multi_head_attention_conv/key_node/bias:0"].assign( + [0.0, 0.0, 0.0]) + + # The attention scores are computed as the product of the transformed + # queries and keys (with a zero diagonal since there are no self edges and + # hence no self-attention), scaled by a factor s. + # + # s * [[0., 1., 0.], [[ 0, s, 0], + # [0., 0., -1], == [ 0, 0, -s], + # [-1, 0., 0.]] [-s, 0, 0]] + # + # Attention weights are computed by applying softmax to each row except + # the diagonal element. Recall that + # softmax([s, 0]) = [exp(s), 1] / (exp(s) + 1), and + # softmax([0, -s]) = softmax([s, 0]) = [exp(s), 1] / (exp(s) + 1), + # which explains the expected values below, with w = exp(s). + + weights["multi_head_attention_conv/value_node/kernel:0"].assign( + # Identity matrix such that the transformed node states are `eye(3)`. + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 0.0], + ]) + weights["multi_head_attention_conv/value_node/bias:0"].assign( + [0.0, 0.0, 0.0]) + + return conv + + named_scalings = { + "none": 1.0, + "rsqrt_dim": 1.0 / math.sqrt(3.0), + "trainable_sigmoid": tf.keras.activations.sigmoid(-5.0), + } + + for scaling_name, scaling_factor in named_scalings.items(): + with self.subTest(f"with_{scaling_name}"): + conv = get_conv(score_scaling=scaling_name) + got = conv(gt_input, edge_set_name="edges") + + # Since the transformed values are just the identity matrix, we recover + # the attention weights for each query. + w = tf.math.exp(scaling_factor).numpy() + want = tf.constant([ + [0.0, w, 1.0], + [w, 0.0, 1.0], + [1.0, w, 0.0], + ]) / tf.constant( + w + 1.0, dtype=tf.float32) + self.assertAllEqual(got.shape, (3, 3)) + self.assertAllClose(got, want, atol=0.0001) + def testNoTransformKeys(self): """Tests that the no key transformation variant of MHA is correct.""" @@ -386,8 +500,7 @@ def testNoTransformKeys(self): self.assertAllEqual(got_2.shape, (3, 2, 3)) self.assertAllClose(got_2, want_2, atol=.0001) - @parameterized.named_parameters(("", False), - ("TransformAfter", True)) + @parameterized.named_parameters(("", False), ("TransformAfter", True)) def testMultihead(self, transform_values_after_pooling): """Extends testBasic with multiple attention heads.""" # The same test graph as in the testBasic above. @@ -404,7 +517,7 @@ def testMultihead(self, transform_values_after_pooling): receiver_tag=tfgnn.TARGET, activation="relu", use_bias=False, # Don't create /bias variables. - score_scaling=False, # Disable score scaling. + score_scaling="none", # Disable score scaling. 
transform_values_after_pooling=transform_values_after_pooling, ) @@ -475,8 +588,7 @@ def testMultihead(self, transform_values_after_pooling): self.assertAllClose(got, want, atol=.0001) @parameterized.named_parameters( - ("", ReloadModel.SKIP, False), - ("TransformAfter", ReloadModel.SKIP, True), + ("", ReloadModel.SKIP, False), ("TransformAfter", ReloadModel.SKIP, True), ("Restored", ReloadModel.SAVED_MODEL, False), ("RestoredTransformAfter", ReloadModel.SAVED_MODEL, True), ("RestoredKeras", ReloadModel.KERAS, False),
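
For readers who want to exercise the new option outside the test harness, here is a minimal usage sketch (not part of the patch above). It assumes the public `tensorflow_gnn` package and the `multi_head_attention` module imported by the tests, and it hand-builds a tiny bidirectional 3-cycle graph in the spirit of the test helper `_get_test_bidi_cycle_graph`, which is defined elsewhere in layers_test.py:

import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import multi_head_attention

# A bidirectional cycle over three nodes with 4-dimensional hidden states,
# using the same node features as the tests above.
graph = tfgnn.GraphTensor.from_pieces(
    node_sets={
        "nodes":
            tfgnn.NodeSet.from_fields(
                sizes=tf.constant([3]),
                features={
                    tfgnn.HIDDEN_STATE:
                        tf.constant([
                            [1.0, 0.0, 0.0, 1.0],
                            [0.0, 1.0, 0.0, 2.0],
                            [0.0, 0.0, 1.0, 3.0],
                        ])
                }),
    },
    edge_sets={
        "edges":
            tfgnn.EdgeSet.from_fields(
                sizes=tf.constant([6]),
                adjacency=tfgnn.Adjacency.from_indices(
                    source=("nodes", tf.constant([0, 1, 2, 0, 2, 1])),
                    target=("nodes", tf.constant([1, 2, 0, 2, 1, 0])),
                )),
    })

# Exercise the three score-scaling modes introduced by this change.
for scaling in ("none", "rsqrt_dim", "trainable_sigmoid"):
  conv = multi_head_attention.MultiHeadAttentionConv(
      num_heads=1,
      per_head_channels=3,
      receiver_tag=tfgnn.TARGET,
      score_scaling=scaling,
  )
  result = conv(graph, edge_set_name="edges")
  print(scaling, result.shape)  # Shape [3, num_heads * per_head_channels].

# With score_scaling="trainable_sigmoid", the layer owns one extra scalar
# weight initialized to -5.0 and multiplies the attention scores by
# sigmoid(-5.0) ~= 0.0067 at the start of training, so attention starts out
# close to uniform and sharpens as that weight is learned.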