
Commit 077f733
Merge pull request #127 from facebookresearch/fixes
Fix bugs introduced in #18
mpu authored Sep 22, 2023
2 parents 2f0f7bb + 3008347 commit 077f733
Showing 2 changed files with 12 additions and 15 deletions.
llama/generation.py (11 changes: 5 additions & 6 deletions)
@@ -19,12 +19,12 @@
 from llama.model import ModelArgs, Transformer
 from llama.tokenizer import Tokenizer
 
-if torch.backends.mps.is_available():
-    device = torch.device('mps')
-elif torch.cuda.is_available():
-    device = torch.device('cuda')
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
 else:
-    device = torch.device('cpu')
+    device = "cpu"
 
 Role = Literal["system", "user", "assistant"]

@@ -85,7 +85,6 @@ def build(
         if device == "cuda":
             torch.cuda.set_device(local_rank)
 
-
         # seed must be the same in all processes
         torch.manual_seed(1)

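The fix is the same in both changed files: #18 tested for MPS before CUDA and stored a torch.device object, while this commit prefers CUDA and keeps device as a plain string, which the device == "cuda" comparison in build then matches directly. A minimal, runnable sketch of the selection logic as it stands after this commit (only names from the diff, nothing assumed):

import torch

# Prefer CUDA, fall back to Apple's MPS backend, then CPU.
# A plain string compares predictably against "cuda" and is accepted
# by torch factory functions just like a torch.device object.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

x = torch.zeros(2, 2, device=device)  # string devices work directly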
llama/model.py (16 changes: 7 additions & 9 deletions)
@@ -15,13 +15,12 @@
 )
 from torch import nn
 
-if torch.backends.mps.is_available():
-    device = torch.device('mps')
-elif torch.cuda.is_available():
-    device = torch.device('cuda')
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
 else:
-    device = torch.device('cpu')
-
+    device = "cpu"
 
 @dataclass
 class ModelArgs:
@@ -81,8 +80,6 @@ def apply_rotary_emb(
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-    if not torch.cuda.is_available():
-        freqs_cis = freqs_cis.to('cpu')
     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
     return xq_out.type_as(xq).to(device), xk_out.type_as(xk).to(device)
@@ -97,7 +94,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
         x[:, :, :, None, :]
         .expand(bs, slen, n_kv_heads, n_rep, head_dim)
         .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-    ) #.to(device)
+    )


@@ -287,6 +284,7 @@ def __init__(self, params: ModelArgs):
     def forward(self, tokens: torch.Tensor, start_pos: int):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
+        self.freqs_cis = self.freqs_cis.to("cuda" if device == "cuda" else "cpu")
         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
 
         mask = None
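The model.py hunks reverse the #18 approach to the rotary table: instead of forcing freqs_cis onto the CPU inside apply_rotary_emb whenever CUDA is absent, forward now pins the precomputed table once per call, to "cuda" on CUDA and to "cpu" everywhere else (including MPS, presumably because complex tensors were not fully supported by the MPS backend at the time), and apply_rotary_emb moves its outputs back to device on return. A minimal sketch of that pinning-and-slicing pattern, with a hypothetical make_freqs standing in for the model's real initializer:

import torch

def make_freqs(dim: int, end: int) -> torch.Tensor:
    # Hypothetical stand-in for the rotary-table initializer: a
    # complex-valued tensor of shape (end, dim // 2), built on CPU.
    inv = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end).float()
    return torch.polar(torch.ones(end, dim // 2), torch.outer(t, inv))

device = "cuda" if torch.cuda.is_available() else "cpu"  # MPS handled like CPU here

freqs_cis = make_freqs(128, 2048)
# Pin the table as the diff's forward does: CUDA gets a device copy,
# everything else keeps the CPU copy.
freqs_cis = freqs_cis.to("cuda" if device == "cuda" else "cpu")

start_pos, seqlen = 0, 16
step = freqs_cis[start_pos : start_pos + seqlen]  # slice for this decode step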

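Earlier in the same file, the repeat_kv hunk only drops a stale .to(device) comment, but the function it touches is the grouped-query-attention helper that tiles each key/value head n_rep times to match the query head count. A quick shape check of the version shown in the diff, completed with the unpacking it relies on:

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Tile each KV head n_rep times along a new axis, then fold that
    # axis into the head dimension, as in the diff.
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

kv = torch.randn(2, 16, 8, 64)   # batch=2, seq=16, 8 KV heads, dim 64
print(repeat_kv(kv, 4).shape)    # torch.Size([2, 16, 32, 64])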
0 comments on commit 077f733
