Integration of heart anomaly detection self-attention neural bag of features #246

Merged: 3 commits, Apr 25, 2022
Changes from 1 commit
4 changes: 2 additions & 2 deletions docs/reference/attention-neural-bag-of-feature-learner.md
@@ -30,9 +30,9 @@ AttentionNeuralBagOfFeatureLearner(self, in_channels, series_length, n_class, n_
- **quantization_type**: *{"nbof", "tnbof"}, default="nbof"*
Specifies the type of quantization layer.
There are two types of quantization layer: the logistic neural bag-of-feature layer ("nbof") or the temporal logistic bag-of-feature layer ("tnbof").
- **attention_type**: *{"spatial", "temporal"}, default="spatial"*
- **attention_type**: *{"spatial", "temporal", "spatialsa", "temporalsa", "spatiotemporal"}, default="spatial"*
Specifies the type of attention mechanism.
There are two types of attention: the spatial attention mechanism that focuses on the different codewords ("spatial") or the temporal attention mechanism that focuses on different temporal instances ("temporal").
There are two types of attention: the spatial attention mechanism that focuses on the different codewords ("spatial") or the temporal attention mechanism that focuses on different temporal instances ("temporal"). Additionally, there are three self-attention-based mechanisms as described [here](https://arxiv.org/abs/2201.11092).
- **lr_scheduler**: *callable, default=`opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.attention_neural_bag_of_feature_learner.get_cosine_lr_scheduler(2e-4, 1e-5)`*
Specifies the function that computes the learning rate, given the total number of epochs `n_epoch` and the current epoch index `epoch_idx`.
That is, the optimizer uses this function to determine the learning rate at a given epoch index.
@@ -5,3 +5,4 @@ This module provides the implementation of the Attention Neural Bag-of-Features
## Sources

The algorithm is implemented according to the paper [Attention-based Neural Bag-of-Features Learning For Sequence Data](https://arxiv.org/abs/2005.12250).
Additionally, three self-attention mechanisms as described in [Self-Attention Neural Bag-of-Features](https://arxiv.org/abs/2201.11092) are implemented.
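
As a quick illustration of the extended `attention_type` parameter documented above, the sketch below constructs the learner with one of the new self-attention variants. This is a minimal, hedged example, not code from the PR: the import path follows the module path shown in the `lr_scheduler` default, and the argument values are illustrative.

```python
# Minimal sketch (assumed usage, not part of the PR): selecting one of the new
# self-attention variants through the `attention_type` parameter.
from opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.attention_neural_bag_of_feature_learner import AttentionNeuralBagOfFeatureLearner

learner = AttentionNeuralBagOfFeatureLearner(
    in_channels=1,                    # single-channel ECG series (illustrative)
    series_length=9000,               # 30 s at 300 Hz (illustrative)
    n_class=4,
    quantization_type="nbof",         # or "tnbof"
    attention_type="spatiotemporal",  # new options: "spatialsa", "temporalsa", "spatiotemporal"
)
```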
@@ -23,6 +23,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.algorithm.samodels import SelfAttention


class ResidualBlock(nn.Module):
@@ -257,19 +258,34 @@ def __init__(self, in_channels, series_length, n_codeword, att_type, n_class, dr
# nbof block
in_channels, series_length = self.compute_intermediate_dimensions(in_channels, series_length)
self.quantization_block = NBoF(in_channels, n_codeword)

self.att_type = att_type
out_dim = n_codeword
# attention block
self.attention_block = Attention(n_codeword, series_length, att_type)
if att_type in ['spatiotemporal', 'spatialsa', 'temporalsa']:
self.attention_block = SelfAttention(n_codeword, series_length, att_type)
self.attention_block2 = SelfAttention(n_codeword, series_length, att_type)
self.attention_block3 = SelfAttention(n_codeword, series_length, att_type)
out_dim = n_codeword*3
else:
self.attention_block = Attention(n_codeword, series_length, att_type)

# classifier
self.classifier = nn.Sequential(nn.Linear(in_features=n_codeword, out_features=512),
self.classifier = nn.Sequential(nn.Linear(in_features=out_dim, out_features=512),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(in_features=512, out_features=n_class))

def forward(self, x):
x = self.resnet_block(x)
x = self.attention_block(self.quantization_block(x)).mean(-1)
x = self.quantization_block(x)
if self.att_type in ['spatiotemporal', 'spatialsa', 'temporalsa']:
x1 = self.attention_block(x)
x2 = self.attention_block2(x)
x3 = self.attention_block3(x)
x = torch.cat([x1, x2, x3], axis=1)
else:
x = self.attention_block(x)
x = x.mean(-1)
x = self.classifier(x)
return x
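
A brief shape sketch of the new three-head path (an illustration under the assumption, consistent with `samodels.py` below, that the quantization block yields `(batch, n_codeword, series_length)` tensors and each `SelfAttention` block preserves that shape), showing why the classifier input grows to `out_dim = n_codeword * 3`:

```python
# Illustrative shape check for the three-head self-attention path (not part of the PR).
import torch

batch, n_codeword, series_length = 8, 256, 128
x = torch.randn(batch, n_codeword, series_length)  # stand-in for the NBoF output
x1 = x2 = x3 = x                                    # each SelfAttention block is shape-preserving
x = torch.cat([x1, x2, x3], dim=1)                  # (batch, 3 * n_codeword, series_length)
x = x.mean(-1)                                      # (batch, 3 * n_codeword)
assert x.shape == (batch, 3 * n_codeword)           # matches the classifier's in_features=out_dim
```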

@@ -298,22 +314,45 @@ def __init__(self, in_channels, series_length, n_codeword, att_type, n_class, dr
# tnbof block
in_channels, series_length = self.compute_intermediate_dimensions(in_channels, series_length)
self.quantization_block = TNBoF(in_channels, n_codeword)
out_dim = n_codeword * 2

# attention block
self.short_attention_block = Attention(n_codeword, series_length - int(series_length / 2), att_type)
self.long_attention_block = Attention(n_codeword, series_length, att_type)
self.att_type = att_type
if att_type in ['spatiotemporal', 'spatialsa', 'temporalsa']:
out_dim = out_dim * 3
self.short_attention_block = SelfAttention(n_codeword, series_length - int(series_length / 2), att_type)
self.long_attention_block = SelfAttention(n_codeword, series_length, att_type)
self.short_attention_block2 = SelfAttention(n_codeword, series_length - int(series_length / 2), att_type)
self.long_attention_block2 = SelfAttention(n_codeword, series_length, att_type)
self.short_attention_block3 = SelfAttention(n_codeword, series_length - int(series_length / 2), att_type)
self.long_attention_block3 = SelfAttention(n_codeword, series_length, att_type)
else:
self.short_attention_block = Attention(n_codeword, series_length - int(series_length / 2), att_type)
self.long_attention_block = Attention(n_codeword, series_length, att_type)

# classifier
self.classifier = nn.Sequential(nn.Linear(in_features=n_codeword * 2, out_features=512),
self.classifier = nn.Sequential(nn.Linear(in_features=out_dim, out_features=512),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(in_features=512, out_features=n_class))

def forward(self, x):
x = self.resnet_block(x)
x_short, x_long = self.quantization_block(x)
x_short = self.short_attention_block(x_short).mean(-1)
x_long = self.long_attention_block(x_long).mean(-1)
if self.att_type in ['spatialsa', 'temporalsa', 'spatiotemporal']:
x_short1 = self.short_attention_block(x_short)
x_long1 = self.long_attention_block(x_long)
x_short2 = self.short_attention_block2(x_short)
x_long2 = self.long_attention_block2(x_long)
x_short3 = self.short_attention_block3(x_short)
x_long3 = self.long_attention_block3(x_long)
x_short = torch.cat([x_short1, x_short2, x_short3], axis=1)
x_long = torch.cat([x_long1, x_long2, x_long3], axis=1)
else:
x_short = self.short_attention_block(x_short)
x_long = self.long_attention_block(x_long)
x_short = x_short.mean(-1)
x_long = x_long.mean(-1)
x = torch.cat([x_short, x_long], dim=-1)
x = self.classifier(x)
return x
New file: algorithm/samodels.py
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
def __init__(self, n_codeword, series_length, att_type):
super(SelfAttention, self).__init__()

assert att_type in ['spatialsa', 'temporalsa', 'spatiotemporal']

self.att_type = att_type
self.hidden_dim = 128

self.n_codeword = n_codeword
self.series_length = series_length

if self.att_type == 'spatiotemporal':
self.w_s = nn.Linear(n_codeword, self.hidden_dim)
self.w_t = nn.Linear(series_length, self.hidden_dim)
elif self.att_type == 'spatialsa':
self.w_1 = nn.Linear(series_length, self.hidden_dim)
self.w_2 = nn.Linear(series_length, self.hidden_dim)
elif self.att_type == 'temporalsa':
self.w_1 = nn.Linear(n_codeword, self.hidden_dim)
self.w_2 = nn.Linear(n_codeword, self.hidden_dim)
self.drop = nn.Dropout(0.2)
self.alpha = nn.Parameter(data=torch.Tensor(1), requires_grad=True)

def forward(self, x):
# dimension order of x: batch_size, in_channels, series_length

# clip the value of alpha to [0, 1]
with torch.no_grad():
self.alpha.copy_(torch.clip(self.alpha, 0.0, 1.0))

if self.att_type == 'spatiotemporal':
q = self.w_t(x)
x_s = x.transpose(-1, -2)
k = self.w_s(x_s)
qkt = q @ k.transpose(-2, -1)*(self.hidden_dim**-0.5)
mask = F.sigmoid(qkt)
x = x * self.alpha + (1.0 - self.alpha) * x * mask

elif self.att_type == 'temporalsa':
x1 = x.transpose(-1, -2)
q = self.w_1(x1)
k = self.w_2(x1)
mask = F.softmax(q @ k.transpose(-2, -1)*(self.hidden_dim**-0.5), dim=-1)
mask = self.drop(mask)
temp = mask @ x1
x1 = x1 * self.alpha + (1.0 - self.alpha) * temp
x = x1.transpose(-2, -1)

elif self.att_type == 'spatialsa':
q = self.w_1(x)
k = self.w_2(x)
mask = F.softmax(q @ k.transpose(-2, -1)*(self.hidden_dim**-0.5), dim=-1)
mask = self.drop(mask)
temp = mask @ x
x = x * self.alpha + (1.0 - self.alpha) * temp

return x
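
A small sanity check, shown here as a hedged usage sketch rather than part of the PR, exercising the new module: every variant returns a tensor with the same `(batch, n_codeword, series_length)` shape as its input, which is what the concatenation in the models above relies on.

```python
# Hedged usage sketch for SelfAttention (not part of the PR).
import torch
from opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.algorithm.samodels import SelfAttention

n_codeword, series_length = 64, 128
x = torch.randn(4, n_codeword, series_length)  # (batch, n_codeword, series_length)

for att_type in ['spatialsa', 'temporalsa', 'spatiotemporal']:
    block = SelfAttention(n_codeword, series_length, att_type)
    y = block(x)
    assert y.shape == x.shape                  # every variant is shape-preserving
```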
@@ -53,6 +53,7 @@
PRETRAINED_N_CODEWORD = [256, 512]
PRETRAINED_QUANT_TYPE = ['nbof', ]
PRETRAINED_ATTENTION_TYPE = ['temporal', ]
PRETRAINED_SA_MODELS = ['AF_nbof_temporalsa_0_30_256']
AF_SAMPLING_RATE = 300


@@ -66,7 +67,7 @@ def __init__(self,
attention_type='spatial',
lr_scheduler=get_cosine_lr_scheduler(1e-3, 1e-5),
optimizer='adam',
weight_decay=0.0,
weight_decay=0,
dropout=0.2,
iters=300,
batch_size=32,
Expand All @@ -87,8 +88,8 @@ def __init__(self,
'Parameter `quantization_type` must be "nbof" or "tnbof"\n' +\
'Provided value: {}'.format(quantization_type)

assert attention_type in ['spatial', 'temporal'],\
'Parameter `attention_type` must be "spatial" or "temporal"\n' +\
assert attention_type in ['spatial', 'temporal', 'spatiotemporal', 'spatialsa', 'temporalsa'],\
'Parameter `attention_type` must be "spatial", "temporal", "spatiotemporal", "spatialsa", or "temporalsa"\n' +\
'Provided value: {}'.format(attention_type)

assert checkpoint_load_iter < iters,\
@@ -496,46 +497,52 @@ def download(self, path, fold_idx):
'Only support pretrained model for the AF dataset, which has 4 classes.\n' +\
'Current model specification has {} classes'.format(self.n_class)

assert fold_idx in [0, 1, 2, 3, 4],\
'`fold_idx` must receive a value from the list [0, 1, 2, 3, 4]\n' +\
'provided value: {}'.format(fold_idx)

sample_length = int(self.series_length / AF_SAMPLING_RATE)
assert sample_length in PRETRAINED_SAMPLE_LENGTH,\
'Current `series_length` does not match supported `series_length`.' +\
'Supported values of `series_length` includes\n' +\
'\n'.join([str(v * AF_SAMPLING_RATE) for v in PRETRAINED_SAMPLE_LENGTH])

assert self.in_channels == 1,\
'The value of `in_channels` parameter must be 1.\n' +\
'Provided value of `in_channels`: {}'.format(self.in_channels)

assert self.n_codeword in PRETRAINED_N_CODEWORD,\
'Current `n_codeword` does not match supported `n_codeword`.' +\
'Supported values of `n_codeword` includes\n' +\
'\n'.join([str(v) for v in PRETRAINED_N_CODEWORD])

assert self.quantization_type in PRETRAINED_QUANT_TYPE,\
'Current `quantization_type` does not match supported `quantization_type`.' +\
'Supported values of `quantization_type` includes\n' +\
'\n'.join([str(v) for v in PRETRAINED_QUANT_TYPE])

assert self.attention_type in PRETRAINED_ATTENTION_TYPE,\
'Current `attention_type` does not match supported `attention_type`.' +\
'Supported values of `attention_type` includes\n' +\
'\n'.join([str(v) for v in PRETRAINED_ATTENTION_TYPE])

server_url = os.path.join(OPENDR_SERVER_URL,
'perception',
'heart_anomaly_detection',
'attention_neural_bag_of_feature')

model_name = 'AF_{}_{}_{}_{}_{}'.format(self.quantization_type,
self.attention_type,
fold_idx,
sample_length,
self.n_codeword)

if self.attention_type in ['temporalsa', 'spatialsa', 'spatiotemporal']:
assert model_name in PRETRAINED_SA_MODELS,\
'Current configuration does not match any pre-trained model.' +\
'Available self-attention models: {}'.format(PRETRAINED_SA_MODELS)
else:
assert fold_idx in [0, 1, 2, 3, 4],\
'`fold_idx` must receive a value from the list [0, 1, 2, 3, 4]\n' +\
'provided value: {}'.format(fold_idx)

assert sample_length in PRETRAINED_SAMPLE_LENGTH,\
'Current `series_length` does not match supported `series_length`.' +\
'Supported values of `series_length` includes\n' +\
'\n'.join([str(v * AF_SAMPLING_RATE) for v in PRETRAINED_SAMPLE_LENGTH])

assert self.in_channels == 1,\
'The value of `in_channels` parameter must be 1.\n' +\
'Provided value of `in_channels`: {}'.format(self.in_channels)

assert self.n_codeword in PRETRAINED_N_CODEWORD,\
'Current `n_codeword` does not match supported `n_codeword`.' +\
'Supported values of `n_codeword` includes\n' +\
'\n'.join([str(v) for v in PRETRAINED_N_CODEWORD])

assert self.quantization_type in PRETRAINED_QUANT_TYPE,\
'Current `quantization_type` does not match supported `quantization_type`.' +\
'Supported values of `quantization_type` includes\n' +\
'\n'.join([str(v) for v in PRETRAINED_QUANT_TYPE])

assert self.attention_type in PRETRAINED_ATTENTION_TYPE,\
'Current `attention_type` does not match supported `attention_type`.' +\
'Supported values of `attention_type` includes\n' +\
'\n'.join([str(v) for v in PRETRAINED_ATTENTION_TYPE])

server_url = os.path.join(OPENDR_SERVER_URL,
'perception',
'heart_anomaly_detection',
'attention_neural_bag_of_feature')

metadata_url = os.path.join(server_url, '{}.json'.format(model_name))
metadata_file = os.path.join(path, 'metadata.json')
urlretrieve(metadata_url, metadata_file)
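
Since the self-attention branch only accepts configurations whose derived name appears in `PRETRAINED_SA_MODELS`, the single downloadable checkpoint is `AF_nbof_temporalsa_0_30_256`: `nbof` quantization, `temporalsa` attention, fold 0, 30-second samples at 300 Hz, and 256 codewords. A hedged sketch of a matching call follows; the `n_codeword` keyword and the local download path are assumptions for illustration.

```python
# Hedged sketch (not from the PR): the only listed self-attention checkpoint is
# 'AF_nbof_temporalsa_0_30_256'. The configuration below is inferred from that
# name; the `n_codeword` keyword and the target directory are assumptions.
import os
from opendr.perception.heart_anomaly_detection.attention_neural_bag_of_feature.attention_neural_bag_of_feature_learner import AttentionNeuralBagOfFeatureLearner

learner = AttentionNeuralBagOfFeatureLearner(
    in_channels=1,              # the AF checkpoints use single-channel input
    series_length=30 * 300,     # 30 s * AF_SAMPLING_RATE (300 Hz)
    n_class=4,                  # AF dataset
    n_codeword=256,
    quantization_type='nbof',
    attention_type='temporalsa',
)

os.makedirs('./pretrained_anbof', exist_ok=True)
learner.download('./pretrained_anbof', fold_idx=0)  # resolves to AF_nbof_temporalsa_0_30_256
```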