fix & add test case
kuacakuaca committed Sep 5, 2024
1 parent c0c706e commit 7f4decd
Showing 3 changed files with 81 additions and 2 deletions.
6 changes: 5 additions & 1 deletion i6_models/parts/conformer/mhsa_rel_pos.py
@@ -240,6 +240,10 @@ def _sinusoidal_pe(pos_seq: torch.Tensor, embed_dim: int):
     inv_freq = 1 / (10000 ** (torch.arange(0.0, embed_dim, 2.0, device=pos_seq.device) / embed_dim))
 
     sinusoid_input = torch.outer(pos_seq, inv_freq)
-    pos_emb = torch.cat([sinusoid_input.sin(), sinusoid_input.cos()], dim=-1)  # [num. positions, embed_dim]
+
+    pos_emb = torch.zeros(pos_seq.shape[0], embed_dim)
+
+    pos_emb[:, 0::2] = sinusoid_input.sin()
+    pos_emb[:, 1::2] = sinusoid_input.cos()
 
     return pos_emb
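
For reference, a small self-contained sketch of the two layouts (function names here are illustrative, not the module's code): the fix switches from concatenating all sin terms before all cos terms to interleaving sin and cos per frequency, presumably to match the convention of ESPnet's RelPositionalEncoding, which the new test below compares against. It assumes an even embed_dim.

import torch

def pe_concatenated(pos_seq: torch.Tensor, embed_dim: int) -> torch.Tensor:
    # Old layout: [sin(f_0), ..., sin(f_{d/2-1}), cos(f_0), ..., cos(f_{d/2-1})]
    inv_freq = 1 / (10000 ** (torch.arange(0.0, embed_dim, 2.0) / embed_dim))
    sinusoid_input = torch.outer(pos_seq, inv_freq)
    return torch.cat([sinusoid_input.sin(), sinusoid_input.cos()], dim=-1)

def pe_interleaved(pos_seq: torch.Tensor, embed_dim: int) -> torch.Tensor:
    # New layout: [sin(f_0), cos(f_0), sin(f_1), cos(f_1), ...]
    inv_freq = 1 / (10000 ** (torch.arange(0.0, embed_dim, 2.0) / embed_dim))
    sinusoid_input = torch.outer(pos_seq, inv_freq)
    pos_emb = torch.zeros(pos_seq.shape[0], embed_dim)
    pos_emb[:, 0::2] = sinusoid_input.sin()
    pos_emb[:, 1::2] = sinusoid_input.cos()
    return pos_emb

pos_seq = torch.arange(-4.0, 5.0)  # relative positions -4 .. 4
a = pe_concatenated(pos_seq, 8)
b = pe_interleaved(pos_seq, 8)
# Same values, different column order: the interleaved layout is a column permutation of the concatenated one.
perm = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7])
assert torch.equal(a[:, perm], b)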
3 changes: 2 additions & 1 deletion requirements_dev.txt
@@ -1,2 +1,3 @@
 onnx
-onnxruntime
+onnxruntime
+espnet
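
The new espnet entry covers the espnet2 imports used by the comparison test added below.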
74 changes: 74 additions & 0 deletions tests/test_conformer_rel_pos.py
@@ -129,3 +129,77 @@ def get_output_shape(
         with_linear_pos=with_linear_pos,
         separate_pos_emb_per_head=separate_pos_emb_per_head,
     ) == [4, 15, 32]
+
+
+def test_ConformerMHSARelPosV1_against_Espnet():
+    from espnet2.asr_transducer.encoder.modules.attention import RelPositionMultiHeadedAttention
+    from espnet2.asr_transducer.encoder.modules.positional_encoding import RelPositionalEncoding
+
+    num_heads = 4
+    embed_size = 256
+    dropout_rate = 0.1
+    batch_dim_size = 4
+    time_dim_size = 50
+
+    espnet_mhsa_module = RelPositionMultiHeadedAttention(
+        num_heads=num_heads, embed_size=embed_size, dropout_rate=dropout_rate
+    )
+    espnet_mhsa_module.eval()
+    espnet_pos_enc_module = RelPositionalEncoding(embed_size, dropout_rate=dropout_rate)
+    espnet_pos_enc_module.eval()
+
+    cfg = ConformerMHSARelPosV1Config(
+        input_dim=embed_size,
+        num_att_heads=num_heads,
+        with_bias=True,
+        att_weights_dropout=dropout_rate,
+        dropout=dropout_rate,
+        learnable_pos_emb=False,
+        with_linear_pos=True,
+        separate_pos_emb_per_head=True,
+        rel_pos_clip=None,
+        with_pos_bias=True,
+        pos_emb_dropout=dropout_rate,
+        dropout_broadcast_axes=None,
+    )
+    own_mhsa_module = ConformerMHSARelPosV1(cfg)
+    own_mhsa_module.eval()
+    own_mhsa_module.linear_pos = espnet_mhsa_module.linear_pos
+    own_mhsa_module.pos_bias_u = espnet_mhsa_module.pos_bias_u
+    own_mhsa_module.pos_bias_v = espnet_mhsa_module.pos_bias_v
+    own_mhsa_module.out_proj = espnet_mhsa_module.linear_out
+    own_mhsa_module.qkv_proj.weight = nn.Parameter(
+        torch.cat(
+            [
+                espnet_mhsa_module.linear_q.weight,
+                espnet_mhsa_module.linear_k.weight,
+                espnet_mhsa_module.linear_v.weight,
+            ],
+            dim=0,
+        )
+    )
+    own_mhsa_module.qkv_proj.bias = nn.Parameter(
+        torch.cat(
+            [espnet_mhsa_module.linear_q.bias, espnet_mhsa_module.linear_k.bias, espnet_mhsa_module.linear_v.bias],
+            dim=0,
+        )
+    )
+
+    input_tensor = torch.rand((batch_dim_size, time_dim_size, embed_size))
+    sequence_mask = torch.ones((batch_dim_size, time_dim_size))
+    inv_sequence_mask = torch.logical_not(sequence_mask)
+
+    input_tensor_layernorm = own_mhsa_module.layernorm(input_tensor)
+
+    espnet_pos_enc = espnet_pos_enc_module(input_tensor_layernorm)
+    espnet_output_tensor = espnet_mhsa_module(
+        query=input_tensor_layernorm,
+        key=input_tensor_layernorm,
+        value=input_tensor_layernorm,
+        pos_enc=espnet_pos_enc,
+        mask=inv_sequence_mask,
+    )
+
+    own_output_tensor = own_mhsa_module(input_tensor, sequence_mask=sequence_mask)
+
+    assert torch.allclose(espnet_output_tensor, own_output_tensor, rtol=1e-03)
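
A note on the weight copying in the test above: nn.Linear stores its weight as [out_features, in_features], so concatenating the separate q/k/v matrices along dim=0 yields a fused projection whose output, split into three equal chunks, reproduces the separate projections. A minimal self-contained sketch (variable names are illustrative, not taken from either library):

import torch
from torch import nn

embed_size = 256
x = torch.rand(4, 50, embed_size)

linear_q = nn.Linear(embed_size, embed_size)
linear_k = nn.Linear(embed_size, embed_size)
linear_v = nn.Linear(embed_size, embed_size)

# Fused projection: weights and biases stacked along the output dimension (dim=0).
qkv_proj = nn.Linear(embed_size, 3 * embed_size)
qkv_proj.weight = nn.Parameter(torch.cat([linear_q.weight, linear_k.weight, linear_v.weight], dim=0))
qkv_proj.bias = nn.Parameter(torch.cat([linear_q.bias, linear_k.bias, linear_v.bias], dim=0))

q, k, v = qkv_proj(x).chunk(3, dim=-1)
assert torch.allclose(q, linear_q(x), atol=1e-6)
assert torch.allclose(k, linear_k(x), atol=1e-6)
assert torch.allclose(v, linear_v(x), atol=1e-6)

The concatenation order used in the test relies on ConformerMHSARelPosV1 splitting its fused qkv_proj output in that same q, k, v order. With espnet installed (see the requirements_dev.txt change above), the comparison can be run with pytest on tests/test_conformer_rel_pos.py.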
