Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor with einops #1173

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
ParallelEmbedding,
Expand Down Expand Up @@ -125,8 +126,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
return rearrange(freqs_cis, 's d -> 1 s 1 d')


def apply_rotary_emb(
Expand All @@ -153,11 +153,16 @@ def apply_rotary_emb(


"""
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# xq_.shape
# (B, Seq_Len, H, Head_dim) -> (B, Seq_Len, H, Head_dim // 2, 2)
# -> (B, Seq_Len, H, Head_dim // 2) complex
xq_ = torch.view_as_complex(rearrange(xq.float(), '... (c d) -> ... c d', d=2))
xk_ = torch.view_as_complex(rearrange(xk.float(), '... (c d) -> ... c d', d=2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
# (B, Seq_Len, H, Head_dim/2) -> (B, Seq_Len, H, Head_dim/2, 2)
# -> (B, Seq_Len, H, Head_dim)
xq_out = rearrange(torch.view_as_real(xq_ * freqs_cis), '... c d -> ... (c d)')
xk_out = rearrange(torch.view_as_real(xk_ * freqs_cis), '... c d -> ... (c d)')
return xq_out.type_as(xq), xk_out.type_as(xk)


Expand All @@ -166,11 +171,8 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)

return rearrange(repeat(x, 'b s h d -> b s h r d', r=n_rep), 'b s h r d -> b s (h r) d')


class Attention(nn.Module):
Expand Down Expand Up @@ -273,9 +275,9 @@ def forward(
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq = rearrange(xq, 'b s (h d) -> b s h d', h=self.n_local_heads)
xk = rearrange(xk, 'b s (h d) -> b s h d', h=self.n_local_kv_heads)
xv = rearrange(xv, 'b s (h d) -> b s h d', h=self.n_local_kv_heads)

xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

Expand All @@ -292,15 +294,15 @@ def forward(
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)

xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
xq = rearrange(xq, 'b s h d -> b h s d') # (bs, n_local_heads, seqlen, head_dim)
keys = rearrange(keys, 'b s h d -> b h s d') # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = rearrange(values, 'b s h d -> b h s d') # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, rearrange(keys, 'b h s d -> b h d s')) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
output = rearrange(output, 'b h s d -> b s (h d)').contiguous()
return self.wo(output)


Expand Down Expand Up @@ -491,5 +493,5 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h).float()
output = self.output(h).float() # explicitly convert to full precision during inference
return output
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ torch
fairscale
fire
sentencepiece
einops