From 9ff486655edcff2573d6bced12a7498ae71741d0 Mon Sep 17 00:00:00 2001 From: katerynaCh Date: Fri, 15 Apr 2022 10:17:41 +0300 Subject: [PATCH 1/2] added sanbof models --- ...attention-neural-bag-of-feature-learner.md | 4 +- .../attention_neural_bag_of_feature/README.md | 1 + .../algorithm/models.py | 57 +++++++++++--- .../algorithm/samodels.py | 63 +++++++++++++++ ...attention_neural_bag_of_feature_learner.py | 77 ++++++++++--------- 5 files changed, 156 insertions(+), 46 deletions(-) create mode 100644 src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/algorithm/samodels.py diff --git a/docs/reference/attention-neural-bag-of-feature-learner.md b/docs/reference/attention-neural-bag-of-feature-learner.md index e0ddaf6fe3..3b323a2e74 100644 --- a/docs/reference/attention-neural-bag-of-feature-learner.md +++ b/docs/reference/attention-neural-bag-of-feature-learner.md @@ -30,9 +30,9 @@ AttentionNeuralBagOfFeatureLearner(self, in_channels, series_length, n_class, n_ - **quantization_type**: *{"nbof", "tnbof"}, default="nbof"* Specifies the type of quantization layer. There are two types of quantization layer: the logistic neural bag-of-feature layer ("nbof") or the temporal logistic bag-of-feature layer ("tnbof"). - - **attention_type**: *{"spatial", "temporal"}, default="spatial"* + - **attention_type**: *{"spatial", "temporal", "spatialsa", "temporalsa", "spatiotemporal"}, default="spatial"* Specifies the type of attention mechanism. - There are two types of attention: the spatial attention mechanism that focuses on the different codewords ("spatial") or the temporal attention mechanism that focuses on different temporal instances ("temporal"). + There are two types of attention: the spatial attention mechanism that focuses on the different codewords ("spatial") or the temporal attention mechanism that focuses on different temporal instances ("temporal"). 
Additionally, there are three self-attention-based mechanisms as described [here](https://arxiv.org/abs/2201.11092). - **lr_scheduler**: *callable, default=`opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.attention_neural_bag_of_feature_learner.get_cosine_lr_scheduler(2e-4, 1e-5)`* Specifies the function that computes the learning rate, given the total number of epochs `n_epoch` and the current epoch index `epoch_idx`. That is, the optimizer uses this function to determine the learning rate at a given epoch index. diff --git a/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/README.md b/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/README.md index 3bb00de38a..75465763f1 100644 --- a/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/README.md +++ b/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/README.md @@ -5,3 +5,4 @@ This module provides the implementation of the Attention Neural Bag-of-Features ## Sources The algorithm is implemented according to the paper [Attention-based Neural Bag-of-Features Learning For Sequence Data](https://arxiv.org/abs/2005.12250). +Additionally, three self-attention mechanisms as described in [Self-Attention Neural Bag-of-Features](https://arxiv.org/abs/2201.11092) are implemented. 
diff --git a/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/algorithm/models.py b/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/algorithm/models.py index 9abbc42ba3..7eaa8ff5e6 100644 --- a/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/algorithm/models.py +++ b/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/algorithm/models.py @@ -23,6 +23,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.algorithm.samodels import SelfAttention class ResidualBlock(nn.Module): @@ -257,19 +258,34 @@ def __init__(self, in_channels, series_length, n_codeword, att_type, n_class, dr # nbof block in_channels, series_length = self.compute_intermediate_dimensions(in_channels, series_length) self.quantization_block = NBoF(in_channels, n_codeword) - + self.att_type = att_type + out_dim = n_codeword # attention block - self.attention_block = Attention(n_codeword, series_length, att_type) + if att_type in ['spatiotemporal', 'spatialsa', 'temporalsa']: + self.attention_block = SelfAttention(n_codeword, series_length, att_type) + self.attention_block2 = SelfAttention(n_codeword, series_length, att_type) + self.attention_block3 = SelfAttention(n_codeword, series_length, att_type) + out_dim = n_codeword*3 + else: + self.attention_block = Attention(n_codeword, series_length, att_type) # classifier - self.classifier = nn.Sequential(nn.Linear(in_features=n_codeword, out_features=512), + self.classifier = nn.Sequential(nn.Linear(in_features=out_dim, out_features=512), nn.ReLU(), nn.Dropout(dropout), nn.Linear(in_features=512, out_features=n_class)) def forward(self, x): x = self.resnet_block(x) - x = self.attention_block(self.quantization_block(x)).mean(-1) + x = self.quantization_block(x) + if self.att_type in ['spatiotemporal', 'spatialsa', 'temporalsa']: + x1 = 
self.attention_block(x) + x2 = self.attention_block2(x) + x3 = self.attention_block3(x) + x = torch.cat([x1, x2, x3], axis=1) + else: + x = self.attention_block(x) + x = x.mean(-1) x = self.classifier(x) return x @@ -298,13 +314,24 @@ def __init__(self, in_channels, series_length, n_codeword, att_type, n_class, dr # tnbof block in_channels, series_length = self.compute_intermediate_dimensions(in_channels, series_length) self.quantization_block = TNBoF(in_channels, n_codeword) + out_dim = n_codeword * 2 # attention block - self.short_attention_block = Attention(n_codeword, series_length - int(series_length / 2), att_type) - self.long_attention_block = Attention(n_codeword, series_length, att_type) + self.att_type = att_type + if att_type in ['spatiotemporal', 'spatialsa', 'temporalsa']: + out_dim = out_dim * 3 + self.short_attention_block = SelfAttention(n_codeword, series_length - int(series_length / 2), att_type) + self.long_attention_block = SelfAttention(n_codeword, series_length, att_type) + self.short_attention_block2 = SelfAttention(n_codeword, series_length - int(series_length / 2), att_type) + self.long_attention_block2 = SelfAttention(n_codeword, series_length, att_type) + self.short_attention_block3 = SelfAttention(n_codeword, series_length - int(series_length / 2), att_type) + self.long_attention_block3 = SelfAttention(n_codeword, series_length, att_type) + else: + self.short_attention_block = Attention(n_codeword, series_length - int(series_length / 2), att_type) + self.long_attention_block = Attention(n_codeword, series_length, att_type) # classifier - self.classifier = nn.Sequential(nn.Linear(in_features=n_codeword * 2, out_features=512), + self.classifier = nn.Sequential(nn.Linear(in_features=out_dim, out_features=512), nn.ReLU(), nn.Dropout(dropout), nn.Linear(in_features=512, out_features=n_class)) @@ -312,8 +339,20 @@ def __init__(self, in_channels, series_length, n_codeword, att_type, n_class, dr def forward(self, x): x = self.resnet_block(x) 
x_short, x_long = self.quantization_block(x) - x_short = self.short_attention_block(x_short).mean(-1) - x_long = self.long_attention_block(x_long).mean(-1) + if self.att_type in ['spatialsa', 'temporalsa', 'spatiotemporal']: + x_short1 = self.short_attention_block(x_short) + x_long1 = self.long_attention_block(x_long) + x_short2 = self.short_attention_block2(x_short) + x_long2 = self.long_attention_block2(x_long) + x_short3 = self.short_attention_block3(x_short) + x_long3 = self.long_attention_block3(x_long) + x_short = torch.cat([x_short1, x_short2, x_short3], axis=1) + x_long = torch.cat([x_long1, x_long2, x_long3], axis=1) + else: + x_short = self.short_attention_block(x_short) + x_long = self.long_attention_block(x_long) + x_short = x_short.mean(-1) + x_long = x_long.mean(-1) x = torch.cat([x_short, x_long], dim=-1) x = self.classifier(x) return x diff --git a/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/algorithm/samodels.py b/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/algorithm/samodels.py new file mode 100644 index 0000000000..de9436af7d --- /dev/null +++ b/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/algorithm/samodels.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SelfAttention(nn.Module): + def __init__(self, n_codeword, series_length, att_type): + super(SelfAttention, self).__init__() + + assert att_type in ['spatialsa', 'temporalsa', 'spatiotemporal'] + + self.att_type = att_type + self.hidden_dim = 128 + + self.n_codeword = n_codeword + self.series_length = series_length + + if self.att_type == 'spatiotemporal': + self.w_s = nn.Linear(n_codeword, self.hidden_dim) + self.w_t = nn.Linear(series_length, self.hidden_dim) + elif self.att_type == 'spatialsa': + self.w_1 = nn.Linear(series_length, self.hidden_dim) + self.w_2 = nn.Linear(series_length, self.hidden_dim) + elif self.att_type == 'temporalsa': + 
self.w_1 = nn.Linear(n_codeword, self.hidden_dim) + self.w_2 = nn.Linear(n_codeword, self.hidden_dim) + self.drop = nn.Dropout(0.2) + self.alpha = nn.Parameter(data=torch.Tensor(1), requires_grad=True) + + def forward(self, x): + # dimension order of x: batch_size, in_channels, series_length + + # clip the value of alpha to [0, 1] + with torch.no_grad(): + self.alpha.copy_(torch.clip(self.alpha, 0.0, 1.0)) + + if self.att_type == 'spatiotemporal': + q = self.w_t(x) + x_s = x.transpose(-1, -2) + k = self.w_s(x_s) + qkt = q @ k.transpose(-2, -1)*(self.hidden_dim**-0.5) + mask = F.sigmoid(qkt) + x = x * self.alpha + (1.0 - self.alpha) * x * mask + + elif self.att_type == 'temporalsa': + x1 = x.transpose(-1, -2) + q = self.w_1(x1) + k = self.w_2(x1) + mask = F.softmax(q @ k.transpose(-2, -1)*(self.hidden_dim**-0.5), dim=-1) + mask = self.drop(mask) + temp = mask @ x1 + x1 = x1 * self.alpha + (1.0 - self.alpha) * temp + x = x1.transpose(-2, -1) + + elif self.att_type == 'spatialsa': + q = self.w_1(x) + k = self.w_2(x) + mask = F.softmax(q @ k.transpose(-2, -1)*(self.hidden_dim**-0.5), dim=-1) + mask = self.drop(mask) + temp = mask @ x + x = x * self.alpha + (1.0 - self.alpha) * temp + + return x diff --git a/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/attention_neural_bag_of_feature_learner.py b/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/attention_neural_bag_of_feature_learner.py index 9af30cd277..e1d7df78e4 100644 --- a/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/attention_neural_bag_of_feature_learner.py +++ b/src/opendr/perception/heart_anomaly_detection/attention_neural_bag_of_feature/attention_neural_bag_of_feature_learner.py @@ -53,6 +53,7 @@ PRETRAINED_N_CODEWORD = [256, 512] PRETRAINED_QUANT_TYPE = ['nbof', ] PRETRAINED_ATTENTION_TYPE = ['temporal', ] +PRETRAINED_SA_MODELS = ['AF_nbof_temporalsa_0_30_256'] AF_SAMPLING_RATE = 300 @@ -66,7 +67,7 @@ def 
__init__(self, attention_type='spatial', lr_scheduler=get_cosine_lr_scheduler(1e-3, 1e-5), optimizer='adam', - weight_decay=0.0, + weight_decay=0, dropout=0.2, iters=300, batch_size=32, @@ -87,8 +88,8 @@ def __init__(self, 'Parameter `quantization_type` must be "nbof" or "tnbof"\n' +\ 'Provided value: {}'.format(quantization_type) - assert attention_type in ['spatial', 'temporal'],\ - 'Parameter `attention_type` must be "spatial" or "temporal"\n' +\ + assert attention_type in ['spatial', 'temporal', 'spatiotemporal', 'spatialsa', 'temporalsa'],\ + 'Parameter `attention_type` must be "spatial", "temporal", "spatiotemporal", "spatialsa", or "temporalsa"\n' +\ 'Provided value: {}'.format(attention_type) assert checkpoint_load_iter < iters,\ @@ -496,39 +497,7 @@ def download(self, path, fold_idx): 'Only support pretrained model for the AF dataset, which has 4 classes.\n' +\ 'Current model specification has {} classes'.format(self.n_class) - assert fold_idx in [0, 1, 2, 3, 4],\ - '`fold_idx` must receive a value from the list [0, 1, 2, 3, 4]\n' +\ - 'provided value: {}'.format(fold_idx) - sample_length = int(self.series_length / AF_SAMPLING_RATE) - assert sample_length in PRETRAINED_SAMPLE_LENGTH,\ - 'Current `series_length` does not match supported `series_length`.' +\ - 'Supported values of `series_length` includes\n' +\ - '\n'.join([str(v * AF_SAMPLING_RATE) for v in PRETRAINED_SAMPLE_LENGTH]) - - assert self.in_channels == 1,\ - 'The value of `in_channels` parameter must be 1.\n' +\ - 'Provided value of `in_channels`: {}'.format(self.in_channels) - - assert self.n_codeword in PRETRAINED_N_CODEWORD,\ - 'Current `n_codeword` does not match supported `n_codeword`.' +\ - 'Supported values of `n_codeword` includes\n' +\ - '\n'.join([str(v) for v in PRETRAINED_N_CODEWORD]) - - assert self.quantization_type in PRETRAINED_QUANT_TYPE,\ - 'Current `quantization_type` does not match supported `quantization_type`.' 
+\ - 'Supported values of `quantization_type` includes\n' +\ - '\n'.join([str(v) for v in PRETRAINED_QUANT_TYPE]) - - assert self.attention_type in PRETRAINED_ATTENTION_TYPE,\ - 'Current `attention_type` does not match supported `attention_type`.' +\ - 'Supported values of `attention_type` includes\n' +\ - '\n'.join([str(v) for v in PRETRAINED_ATTENTION_TYPE]) - - server_url = os.path.join(OPENDR_SERVER_URL, - 'perception', - 'heart_anomaly_detection', - 'attention_neural_bag_of_feature') model_name = 'AF_{}_{}_{}_{}_{}'.format(self.quantization_type, self.attention_type, @@ -536,6 +505,44 @@ def download(self, path, fold_idx): sample_length, self.n_codeword) + if self.attention_type in ['temporalsa', 'spatialsa', 'spatiotemporal']: + assert model_name in PRETRAINED_SA_MODELS,\ + 'Current configuration does not match any pre-trained model.' +\ + 'Available self-attention models: {}'.format(PRETRAINED_SA_MODELS) + else: + assert fold_idx in [0, 1, 2, 3, 4],\ + '`fold_idx` must receive a value from the list [0, 1, 2, 3, 4]\n' +\ + 'provided value: {}'.format(fold_idx) + + assert sample_length in PRETRAINED_SAMPLE_LENGTH,\ + 'Current `series_length` does not match supported `series_length`.' +\ + 'Supported values of `series_length` includes\n' +\ + '\n'.join([str(v * AF_SAMPLING_RATE) for v in PRETRAINED_SAMPLE_LENGTH]) + + assert self.in_channels == 1,\ + 'The value of `in_channels` parameter must be 1.\n' +\ + 'Provided value of `in_channels`: {}'.format(self.in_channels) + + assert self.n_codeword in PRETRAINED_N_CODEWORD,\ + 'Current `n_codeword` does not match supported `n_codeword`.' +\ + 'Supported values of `n_codeword` includes\n' +\ + '\n'.join([str(v) for v in PRETRAINED_N_CODEWORD]) + + assert self.quantization_type in PRETRAINED_QUANT_TYPE,\ + 'Current `quantization_type` does not match supported `quantization_type`.' 
+\ + 'Supported values of `quantization_type` includes\n' +\ + '\n'.join([str(v) for v in PRETRAINED_QUANT_TYPE]) + + assert self.attention_type in PRETRAINED_ATTENTION_TYPE,\ + 'Current `attention_type` does not match supported `attention_type`.' +\ + 'Supported values of `attention_type` includes\n' +\ + '\n'.join([str(v) for v in PRETRAINED_ATTENTION_TYPE]) + + server_url = os.path.join(OPENDR_SERVER_URL, + 'perception', + 'heart_anomaly_detection', + 'attention_neural_bag_of_feature') + metadata_url = os.path.join(server_url, '{}.json'.format(model_name)) metadata_file = os.path.join(path, 'metadata.json') urlretrieve(metadata_url, metadata_file) From c3e41d65b6107284f21277427077b05414b461c3 Mon Sep 17 00:00:00 2001 From: katerynaCh Date: Mon, 18 Apr 2022 19:33:23 +0300 Subject: [PATCH 2/2] added attention models to ci test --- .../test_attention_neural_bag_of_feature_learner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/sources/tools/perception/heart_anomaly_detection/attention_neural_bag_of_feature/test_attention_neural_bag_of_feature_learner.py b/tests/sources/tools/perception/heart_anomaly_detection/attention_neural_bag_of_feature/test_attention_neural_bag_of_feature_learner.py index cb0dc1d2f3..ff0e84be36 100644 --- a/tests/sources/tools/perception/heart_anomaly_detection/attention_neural_bag_of_feature/test_attention_neural_bag_of_feature_learner.py +++ b/tests/sources/tools/perception/heart_anomaly_detection/attention_neural_bag_of_feature/test_attention_neural_bag_of_feature_learner.py @@ -58,7 +58,7 @@ def test_fit(self): series_length = random.choice([30 * 300, 40 * 300]) n_class = np.random.randint(low=2, high=100) quantization_type = random.choice(['nbof', 'tnbof']) - attention_type = random.choice(['spatial', 'temporal']) + attention_type = random.choice(['spatial', 'temporal', 'spatialsa', 'temporalsa', 'spatiotemporal']) train_set = DummyDataset(in_channels, series_length, n_class) val_set = 
DummyDataset(in_channels, series_length, n_class) @@ -85,7 +85,7 @@ def test_eval(self): series_length = random.choice([30 * 300, 40 * 300]) n_class = np.random.randint(low=2, high=100) quantization_type = random.choice(['nbof', 'tnbof']) - attention_type = random.choice(['spatial', 'temporal']) + attention_type = random.choice(['spatial', 'temporal', 'spatialsa', 'temporalsa', 'spatiotemporal']) learner = AttentionNeuralBagOfFeatureLearner(in_channels, series_length, @@ -110,7 +110,7 @@ def test_infer(self): series_length = random.choice([30 * 300, 40 * 300]) n_class = np.random.randint(low=2, high=100) quantization_type = random.choice(['nbof', 'tnbof']) - attention_type = random.choice(['spatial', 'temporal']) + attention_type = random.choice(['spatial', 'temporal', 'spatialsa', 'temporalsa', 'spatiotemporal']) learner = AttentionNeuralBagOfFeatureLearner(in_channels, series_length, @@ -132,7 +132,7 @@ def test_save_load(self): series_length = random.choice([30 * 300, 40 * 300]) n_class = np.random.randint(low=2, high=100) quantization_type = random.choice(['nbof', 'tnbof']) - attention_type = random.choice(['spatial', 'temporal']) + attention_type = random.choice(['spatial', 'temporal', 'spatialsa', 'temporalsa', 'spatiotemporal']) learner = AttentionNeuralBagOfFeatureLearner(in_channels, series_length,