From 65513d81a7ffc14c76379b74e6a8fc8199ef1f5e Mon Sep 17 00:00:00 2001 From: Mddct Date: Sun, 14 Apr 2024 18:34:33 +0800 Subject: [PATCH] refactor cache behaviour in training mode (reduce compute cost and memory) --- wenet/transformer/attention.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py index 54b76daad..c65d8de2e 100644 --- a/wenet/transformer/attention.py +++ b/wenet/transformer/attention.py @@ -226,7 +226,7 @@ def forward( # >>> torch.equal(b, c) # True # >>> d = torch.split(a, 2, dim=-1) # >>> torch.equal(d[0], d[1]) # True - if cache.size(0) > 0: + if cache.size(0) > 0 and not self.training: key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1) @@ -234,7 +234,7 @@ def forward( v = torch.cat([value_cache, v], dim=2) # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's # non-trivial to calculate `next_cache_start` here. - new_cache = torch.cat((k, v), dim=-1) + new_cache = torch.cat((k, v), dim=-1) if not self.training else cache # for multi query or multi group attention if self.h_kv != self.h: @@ -370,7 +370,7 @@ def forward( # >>> torch.equal(b, c) # True # >>> d = torch.split(a, 2, dim=-1) # >>> torch.equal(d[0], d[1]) # True - if cache.size(0) > 0: + if cache.size(0) > 0 and not self.training: key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1) @@ -379,7 +379,7 @@ def forward( # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's # non-trivial to calculate `next_cache_start` here. - new_cache = torch.cat((k, v), dim=-1) + new_cache = torch.cat((k, v), dim=-1) if not self.training else cache # for multi query or multi groups attention if self.h_kv != self.h: @@ -472,7 +472,7 @@ def forward( else: q, k, v = self.forward_qkv(query, key, value) - new_cache = torch.cat((k, v), dim=-1) + new_cache = torch.cat((k, v), dim=-1) if not self.training else cache # for multi query or multi groups attention if self.h_kv != self.h: @@ -563,13 +563,13 @@ def forward( ) -> Tuple[torch.Tensor, torch.Tensor]: del pos_emb q, k, v = self.forward_qkv(query, key, value) - if cache.size(0) > 0: + if cache.size(0) > 0 and not self.training: key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1) k = torch.cat([key_cache, k], dim=2) v = torch.cat([value_cache, v], dim=2) - new_cache = torch.cat((k, v), dim=-1) + new_cache = torch.cat((k, v), dim=-1) if not self.training else cache rel_k = self.rel_k_embed( self._relative_indices(k.size(2), query.device)) # (t2, t2, d_k) @@ -664,13 +664,13 @@ def forward( q = llama_apply_rotary_emb(q, pos_emb) k = llama_apply_rotary_emb(k, pos_emb) # see above - if cache.size(0) > 0: + if cache.size(0) > 0 and not self.training: key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1) k = torch.cat([key_cache, k], dim=2) v = torch.cat([value_cache, v], dim=2) - new_cache = torch.cat((k, v), dim=-1) + new_cache = torch.cat((k, v), dim=-1) if not self.training else cache if self.h_kv != self.h: k = torch.repeat_interleave(