diff --git a/llama/generation.py b/llama/generation.py
index e1f8aca2..a5b3c7b1 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -19,12 +19,12 @@
 from llama.model import ModelArgs, Transformer
 from llama.tokenizer import Tokenizer
 
-if torch.backends.mps.is_available():
-    device = torch.device('mps')
-elif torch.cuda.is_available():
-    device = torch.device('cuda')
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
 else:
-    device = torch.device('cpu')
+    device = "cpu"
 
 Role = Literal["system", "user", "assistant"]
 
@@ -85,7 +85,6 @@ def build(
         if device == "cuda":
            torch.cuda.set_device(local_rank)
 
-        # seed must be the same in all processes
         torch.manual_seed(1)
 
diff --git a/llama/model.py b/llama/model.py
index 8f72daa8..90b535b6 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -15,13 +15,12 @@
 )
 from torch import nn
 
-if torch.backends.mps.is_available():
-    device = torch.device('mps')
-elif torch.cuda.is_available():
-    device = torch.device('cuda')
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
 else:
-    device = torch.device('cpu')
-
+    device = "cpu"
 
 @dataclass
 class ModelArgs:
@@ -81,8 +80,6 @@ def apply_rotary_emb(
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-    if not torch.cuda.is_available():
-        freqs_cis = freqs_cis.to('cpu')
     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
     return xq_out.type_as(xq).to(device), xk_out.type_as(xk).to(device)
@@ -97,7 +94,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
         x[:, :, :, None, :]
         .expand(bs, slen, n_kv_heads, n_rep, head_dim)
         .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-    )  #.to(device)
+    )
 
 
 class Attention(nn.Module):
@@ -287,6 +284,7 @@ def __init__(self, params: ModelArgs):
     def forward(self, tokens: torch.Tensor, start_pos: int):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
+        self.freqs_cis = self.freqs_cis.to("cuda" if device == "cuda" else "cpu")
         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
 
         mask = None
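
Taken together, both files now resolve the device the same way, CUDA first, then MPS, then CPU, held as a plain string, and model.py pins the rotary-embedding table (freqs_cis) to CPU whenever CUDA is not the active device, presumably to avoid complex tensors on the MPS backend. Below is a minimal standalone sketch of that behaviour, not part of the patch itself; it assumes a PyTorch build that exposes torch.backends.mps, and the freqs_cis shape is illustrative only.

import torch

# Device-selection order the patch settles on: prefer CUDA, then MPS, then CPU,
# with the device represented as a plain string rather than a torch.device.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Mirror the freqs_cis handling in model.py: the complex-valued frequency table
# lives on CUDA when available, otherwise it stays on CPU (never on MPS).
freqs_cis_device = "cuda" if device == "cuda" else "cpu"
freqs_cis = torch.polar(torch.ones(8, 4), torch.zeros(8, 4)).to(freqs_cis_device)

print(device, freqs_cis.device)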