Integration of heart anomaly detection self-attention neural bag of features (#246)

* added sanbof models

* added attention models to ci test

Co-authored-by: ad-daniel <[email protected]>
katerynaCh and ad-daniel authored Apr 25, 2022
1 parent a360896 commit 5d235af
Showing 6 changed files with 160 additions and 50 deletions.
4 changes: 2 additions & 2 deletions docs/reference/attention-neural-bag-of-feature-learner.md
@@ -30,9 +30,9 @@ AttentionNeuralBagOfFeatureLearner(self, in_channels, series_length, n_class, n_
- **quantization_type**: *{"nbof", "tnbof"}, default="nbof"*
Specifies the type of quantization layer.
There are two types of quantization layer: the logistic neural bag-of-feature layer ("nbof") or the temporal logistic bag-of-feature layer ("tnbof").
- **attention_type**: *{"spatial", "temporal"}, default="spatial"*
- **attention_type**: *{"spatial", "temporal", "spatialsa", "temporalsa", "spatiotemporal"}, default="spatial"*
Specifies the type of attention mechanism.
There are two types of attention: the spatial attention mechanism that focuses on the different codewords ("spatial") or the temporal attention mechanism that focuses on different temporal instances ("temporal").
There are two types of attention: the spatial attention mechanism that focuses on the different codewords ("spatial") or the temporal attention mechanism that focuses on different temporal instances ("temporal"). Additionally, there are three self-attention-based mechanisms ("spatialsa", "temporalsa", "spatiotemporal"), as described [here](https://arxiv.org/abs/2201.11092).
- **lr_scheduler**: *callable, default=`opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.attention_neural_bag_of_feature_learner.get_cosine_lr_scheduler(2e-4, 1e-5)`*
Specifies the function that computes the learning rate, given the total number of epochs `n_epoch` and the current epoch index `epoch_idx`.
That is, the optimizer uses this function to determine the learning rate at a given epoch index.
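For illustration, a minimal sketch of constructing the learner with one of the self-attention attention types listed above. The import path follows the module path given for the default `lr_scheduler`; keyword names such as `n_codeword` are assumed from the surrounding code rather than confirmed by this excerpt.

```python
# Sketch only: keyword names (e.g. n_codeword) are assumed, not confirmed by the excerpt above.
from opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.attention_neural_bag_of_feature_learner import AttentionNeuralBagOfFeatureLearner

learner = AttentionNeuralBagOfFeatureLearner(in_channels=1,               # single-lead ECG
                                             series_length=30 * 300,      # 30 s sampled at 300 Hz
                                             n_class=4,
                                             n_codeword=256,
                                             quantization_type='nbof',
                                             attention_type='spatiotemporal')  # or 'spatialsa' / 'temporalsa'
```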
@@ -5,3 +5,4 @@ This module provides the implementation of the Attention Neural Bag-of-Features
## Sources

The algorithm is implemented according to the paper [Attention-based Neural Bag-of-Features Learning For Sequence Data](https://arxiv.org/abs/2005.12250).
Additionally, three self-attention mechanisms as described in [Self-Attention Neural Bag-of-Features](https://arxiv.org/abs/2201.11092) are implemented.
@@ -23,6 +23,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.algorithm.samodels import SelfAttention


class ResidualBlock(nn.Module):
@@ -257,19 +258,34 @@ def __init__(self, in_channels, series_length, n_codeword, att_type, n_class, dr
        # nbof block
        in_channels, series_length = self.compute_intermediate_dimensions(in_channels, series_length)
        self.quantization_block = NBoF(in_channels, n_codeword)

        self.att_type = att_type
        out_dim = n_codeword
        # attention block
        self.attention_block = Attention(n_codeword, series_length, att_type)
        if att_type in ['spatiotemporal', 'spatialsa', 'temporalsa']:
            self.attention_block = SelfAttention(n_codeword, series_length, att_type)
            self.attention_block2 = SelfAttention(n_codeword, series_length, att_type)
            self.attention_block3 = SelfAttention(n_codeword, series_length, att_type)
            out_dim = n_codeword*3
        else:
            self.attention_block = Attention(n_codeword, series_length, att_type)

        # classifier
        self.classifier = nn.Sequential(nn.Linear(in_features=n_codeword, out_features=512),
        self.classifier = nn.Sequential(nn.Linear(in_features=out_dim, out_features=512),
                                        nn.ReLU(),
                                        nn.Dropout(dropout),
                                        nn.Linear(in_features=512, out_features=n_class))

    def forward(self, x):
        x = self.resnet_block(x)
        x = self.attention_block(self.quantization_block(x)).mean(-1)
        x = self.quantization_block(x)
        if self.att_type in ['spatiotemporal', 'spatialsa', 'temporalsa']:
            x1 = self.attention_block(x)
            x2 = self.attention_block2(x)
            x3 = self.attention_block3(x)
            x = torch.cat([x1, x2, x3], axis=1)
        else:
            x = self.attention_block(x)
        x = x.mean(-1)
        x = self.classifier(x)
        return x

@@ -298,22 +314,45 @@ def __init__(self, in_channels, series_length, n_codeword, att_type, n_class, dr
        # tnbof block
        in_channels, series_length = self.compute_intermediate_dimensions(in_channels, series_length)
        self.quantization_block = TNBoF(in_channels, n_codeword)
        out_dim = n_codeword * 2

        # attention block
        self.short_attention_block = Attention(n_codeword, series_length - int(series_length / 2), att_type)
        self.long_attention_block = Attention(n_codeword, series_length, att_type)
        self.att_type = att_type
        if att_type in ['spatiotemporal', 'spatialsa', 'temporalsa']:
            out_dim = out_dim * 3
            self.short_attention_block = SelfAttention(n_codeword, series_length - int(series_length / 2), att_type)
            self.long_attention_block = SelfAttention(n_codeword, series_length, att_type)
            self.short_attention_block2 = SelfAttention(n_codeword, series_length - int(series_length / 2), att_type)
            self.long_attention_block2 = SelfAttention(n_codeword, series_length, att_type)
            self.short_attention_block3 = SelfAttention(n_codeword, series_length - int(series_length / 2), att_type)
            self.long_attention_block3 = SelfAttention(n_codeword, series_length, att_type)
        else:
            self.short_attention_block = Attention(n_codeword, series_length - int(series_length / 2), att_type)
            self.long_attention_block = Attention(n_codeword, series_length, att_type)

        # classifier
        self.classifier = nn.Sequential(nn.Linear(in_features=n_codeword * 2, out_features=512),
        self.classifier = nn.Sequential(nn.Linear(in_features=out_dim, out_features=512),
                                        nn.ReLU(),
                                        nn.Dropout(dropout),
                                        nn.Linear(in_features=512, out_features=n_class))

    def forward(self, x):
        x = self.resnet_block(x)
        x_short, x_long = self.quantization_block(x)
        x_short = self.short_attention_block(x_short).mean(-1)
        x_long = self.long_attention_block(x_long).mean(-1)
        if self.att_type in ['spatialsa', 'temporalsa', 'spatiotemporal']:
            x_short1 = self.short_attention_block(x_short)
            x_long1 = self.long_attention_block(x_long)
            x_short2 = self.short_attention_block2(x_short)
            x_long2 = self.long_attention_block2(x_long)
            x_short3 = self.short_attention_block3(x_short)
            x_long3 = self.long_attention_block3(x_long)
            x_short = torch.cat([x_short1, x_short2, x_short3], axis=1)
            x_long = torch.cat([x_long1, x_long2, x_long3], axis=1)
        else:
            x_short = self.short_attention_block(x_short)
            x_long = self.long_attention_block(x_long)
        x_short = x_short.mean(-1)
        x_long = x_long.mean(-1)
        x = torch.cat([x_short, x_long], dim=-1)
        x = self.classifier(x)
        return x
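The classifier width bookkeeping in the two constructors above is easy to lose track of; a quick worked check with illustrative values:

```python
# Illustrative only: 256 codewords, TNBoF quantization, a self-attention variant.
n_codeword = 256
att_type = 'temporalsa'

out_dim = n_codeword * 2                 # TNBoF yields a short- and a long-window branch
if att_type in ['spatiotemporal', 'spatialsa', 'temporalsa']:
    out_dim = out_dim * 3                # three parallel self-attention heads per branch
print(out_dim)                           # 1536 -> in_features of the first classifier Linear
```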
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, n_codeword, series_length, att_type):
        super(SelfAttention, self).__init__()

        assert att_type in ['spatialsa', 'temporalsa', 'spatiotemporal']

        self.att_type = att_type
        self.hidden_dim = 128

        self.n_codeword = n_codeword
        self.series_length = series_length

        if self.att_type == 'spatiotemporal':
            self.w_s = nn.Linear(n_codeword, self.hidden_dim)
            self.w_t = nn.Linear(series_length, self.hidden_dim)
        elif self.att_type == 'spatialsa':
            self.w_1 = nn.Linear(series_length, self.hidden_dim)
            self.w_2 = nn.Linear(series_length, self.hidden_dim)
        elif self.att_type == 'temporalsa':
            self.w_1 = nn.Linear(n_codeword, self.hidden_dim)
            self.w_2 = nn.Linear(n_codeword, self.hidden_dim)
        self.drop = nn.Dropout(0.2)
        self.alpha = nn.Parameter(data=torch.Tensor(1), requires_grad=True)

    def forward(self, x):
        # dimension order of x: batch_size, in_channels, series_length

        # clip the value of alpha to [0, 1]
        with torch.no_grad():
            self.alpha.copy_(torch.clip(self.alpha, 0.0, 1.0))

        if self.att_type == 'spatiotemporal':
            q = self.w_t(x)
            x_s = x.transpose(-1, -2)
            k = self.w_s(x_s)
            qkt = q @ k.transpose(-2, -1)*(self.hidden_dim**-0.5)
            mask = F.sigmoid(qkt)
            x = x * self.alpha + (1.0 - self.alpha) * x * mask

        elif self.att_type == 'temporalsa':
            x1 = x.transpose(-1, -2)
            q = self.w_1(x1)
            k = self.w_2(x1)
            mask = F.softmax(q @ k.transpose(-2, -1)*(self.hidden_dim**-0.5), dim=-1)
            mask = self.drop(mask)
            temp = mask @ x1
            x1 = x1 * self.alpha + (1.0 - self.alpha) * temp
            x = x1.transpose(-2, -1)

        elif self.att_type == 'spatialsa':
            q = self.w_1(x)
            k = self.w_2(x)
            mask = F.softmax(q @ k.transpose(-2, -1)*(self.hidden_dim**-0.5), dim=-1)
            mask = self.drop(mask)
            temp = mask @ x
            x = x * self.alpha + (1.0 - self.alpha) * temp

        return x
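A minimal shape sketch of the module above (sizes are illustrative): every variant gates its input with the learnable `alpha` and returns a tensor of the same `(batch_size, n_codeword, series_length)` shape, so it can be dropped in place of the plain `Attention` blocks.

```python
import torch

from opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.algorithm.samodels import SelfAttention

att = SelfAttention(n_codeword=64, series_length=128, att_type='temporalsa')  # illustrative sizes
x = torch.randn(8, 64, 128)   # (batch_size, n_codeword, series_length)
y = att(x)
print(y.shape)                # torch.Size([8, 64, 128]) -- the shape is preserved
```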
@@ -53,6 +53,7 @@
PRETRAINED_N_CODEWORD = [256, 512]
PRETRAINED_QUANT_TYPE = ['nbof', ]
PRETRAINED_ATTENTION_TYPE = ['temporal', ]
PRETRAINED_SA_MODELS = ['AF_nbof_temporalsa_0_30_256']
AF_SAMPLING_RATE = 300


@@ -66,7 +67,7 @@ def __init__(self,
                 attention_type='spatial',
                 lr_scheduler=get_cosine_lr_scheduler(1e-3, 1e-5),
                 optimizer='adam',
                 weight_decay=0.0,
                 weight_decay=0,
                 dropout=0.2,
                 iters=300,
                 batch_size=32,
@@ -87,8 +88,8 @@ def __init__(self,
            'Parameter `quantization_type` must be "nbof" or "tnbof"\n' +\
            'Provided value: {}'.format(quantization_type)

        assert attention_type in ['spatial', 'temporal'],\
            'Parameter `attention_type` must be "spatial" or "temporal"\n' +\
        assert attention_type in ['spatial', 'temporal', 'spatiotemporal', 'spatialsa', 'temporalsa'],\
            'Parameter `attention_type` must be "spatial", "temporal", "spatiotemporal", "spatialsa", or "temporalsa"\n' +\
            'Provided value: {}'.format(attention_type)

        assert checkpoint_load_iter < iters,\
@@ -496,46 +497,52 @@ def download(self, path, fold_idx):
            'Only support pretrained model for the AF dataset, which has 4 classes.\n' +\
            'Current model specification has {} classes'.format(self.n_class)

        assert fold_idx in [0, 1, 2, 3, 4],\
            '`fold_idx` must receive a value from the list [0, 1, 2, 3, 4]\n' +\
            'provided value: {}'.format(fold_idx)

        sample_length = int(self.series_length / AF_SAMPLING_RATE)
        assert sample_length in PRETRAINED_SAMPLE_LENGTH,\
            'Current `series_length` does not match supported `series_length`.' +\
            'Supported values of `series_length` includes\n' +\
            '\n'.join([str(v * AF_SAMPLING_RATE) for v in PRETRAINED_SAMPLE_LENGTH])

        assert self.in_channels == 1,\
            'The value of `in_channels` parameter must be 1.\n' +\
            'Provided value of `in_channels`: {}'.format(self.in_channels)

        assert self.n_codeword in PRETRAINED_N_CODEWORD,\
            'Current `n_codeword` does not match supported `n_codeword`.' +\
            'Supported values of `n_codeword` includes\n' +\
            '\n'.join([str(v) for v in PRETRAINED_N_CODEWORD])

        assert self.quantization_type in PRETRAINED_QUANT_TYPE,\
            'Current `quantization_type` does not match supported `quantization_type`.' +\
            'Supported values of `quantization_type` includes\n' +\
            '\n'.join([str(v) for v in PRETRAINED_QUANT_TYPE])

        assert self.attention_type in PRETRAINED_ATTENTION_TYPE,\
            'Current `attention_type` does not match supported `attention_type`.' +\
            'Supported values of `attention_type` includes\n' +\
            '\n'.join([str(v) for v in PRETRAINED_ATTENTION_TYPE])

        server_url = os.path.join(OPENDR_SERVER_URL,
                                  'perception',
                                  'heart_anomaly_detection',
                                  'attention_neural_bag_of_feature')

        model_name = 'AF_{}_{}_{}_{}_{}'.format(self.quantization_type,
                                                self.attention_type,
                                                fold_idx,
                                                sample_length,
                                                self.n_codeword)

        if self.attention_type in ['temporalsa', 'spatialsa', 'spatiotemporal']:
            assert model_name in PRETRAINED_SA_MODELS,\
                'Current configuration does not match any pre-trained model.' +\
                'Available self-attention models: {}'.format(PRETRAINED_SA_MODELS)
        else:
            assert fold_idx in [0, 1, 2, 3, 4],\
                '`fold_idx` must receive a value from the list [0, 1, 2, 3, 4]\n' +\
                'provided value: {}'.format(fold_idx)

            assert sample_length in PRETRAINED_SAMPLE_LENGTH,\
                'Current `series_length` does not match supported `series_length`.' +\
                'Supported values of `series_length` includes\n' +\
                '\n'.join([str(v * AF_SAMPLING_RATE) for v in PRETRAINED_SAMPLE_LENGTH])

            assert self.in_channels == 1,\
                'The value of `in_channels` parameter must be 1.\n' +\
                'Provided value of `in_channels`: {}'.format(self.in_channels)

            assert self.n_codeword in PRETRAINED_N_CODEWORD,\
                'Current `n_codeword` does not match supported `n_codeword`.' +\
                'Supported values of `n_codeword` includes\n' +\
                '\n'.join([str(v) for v in PRETRAINED_N_CODEWORD])

            assert self.quantization_type in PRETRAINED_QUANT_TYPE,\
                'Current `quantization_type` does not match supported `quantization_type`.' +\
                'Supported values of `quantization_type` includes\n' +\
                '\n'.join([str(v) for v in PRETRAINED_QUANT_TYPE])

            assert self.attention_type in PRETRAINED_ATTENTION_TYPE,\
                'Current `attention_type` does not match supported `attention_type`.' +\
                'Supported values of `attention_type` includes\n' +\
                '\n'.join([str(v) for v in PRETRAINED_ATTENTION_TYPE])

        server_url = os.path.join(OPENDR_SERVER_URL,
                                  'perception',
                                  'heart_anomaly_detection',
                                  'attention_neural_bag_of_feature')

        metadata_url = os.path.join(server_url, '{}.json'.format(model_name))
        metadata_file = os.path.join(path, 'metadata.json')
        urlretrieve(metadata_url, metadata_file)
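The list of pre-trained self-attention checkpoints currently holds a single entry, `AF_nbof_temporalsa_0_30_256`. A hedged sketch of a configuration that matches it and downloads the checkpoint; as above, keyword names such as `n_codeword` are assumed, while `download(path, fold_idx)` follows the signature shown in the diff.

```python
from opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.attention_neural_bag_of_feature_learner import AttentionNeuralBagOfFeatureLearner

# 'AF_nbof_temporalsa_0_30_256': nbof quantization, temporal self-attention,
# fold 0, 30 s at 300 Hz (series_length = 9000), 256 codewords, 4 AF classes.
learner = AttentionNeuralBagOfFeatureLearner(in_channels=1,
                                             series_length=30 * 300,
                                             n_class=4,
                                             n_codeword=256,              # keyword name assumed
                                             quantization_type='nbof',
                                             attention_type='temporalsa')
learner.download('./pretrained_models', fold_idx=0)
```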
@@ -58,7 +58,7 @@ def test_fit(self):
        series_length = random.choice([30 * 300, 40 * 300])
        n_class = np.random.randint(low=2, high=100)
        quantization_type = random.choice(['nbof', 'tnbof'])
        attention_type = random.choice(['spatial', 'temporal'])
        attention_type = random.choice(['spatial', 'temporal', 'spatialsa', 'temporalsa', 'spatiotemporal'])

        train_set = DummyDataset(in_channels, series_length, n_class)
        val_set = DummyDataset(in_channels, series_length, n_class)
@@ -85,7 +85,7 @@ def test_eval(self):
        series_length = random.choice([30 * 300, 40 * 300])
        n_class = np.random.randint(low=2, high=100)
        quantization_type = random.choice(['nbof', 'tnbof'])
        attention_type = random.choice(['spatial', 'temporal'])
        attention_type = random.choice(['spatial', 'temporal', 'spatialsa', 'temporalsa', 'spatiotemporal'])

        learner = AttentionNeuralBagOfFeatureLearner(in_channels,
                                                     series_length,
@@ -110,7 +110,7 @@ def test_infer(self):
        series_length = random.choice([30 * 300, 40 * 300])
        n_class = np.random.randint(low=2, high=100)
        quantization_type = random.choice(['nbof', 'tnbof'])
        attention_type = random.choice(['spatial', 'temporal'])
        attention_type = random.choice(['spatial', 'temporal', 'spatialsa', 'temporalsa', 'spatiotemporal'])

        learner = AttentionNeuralBagOfFeatureLearner(in_channels,
                                                     series_length,
@@ -132,7 +132,7 @@ def test_save_load(self):
        series_length = random.choice([30 * 300, 40 * 300])
        n_class = np.random.randint(low=2, high=100)
        quantization_type = random.choice(['nbof', 'tnbof'])
        attention_type = random.choice(['spatial', 'temporal'])
        attention_type = random.choice(['spatial', 'temporal', 'spatialsa', 'temporalsa', 'spatiotemporal'])

        learner = AttentionNeuralBagOfFeatureLearner(in_channels,
                                                     series_length,
