Changes to run Llama2 on Apple Mac M2 / MPS
GD Dev committed Jul 21, 2023
1 parent 6c7fe27 commit 9a5670b
Showing 3 changed files with 46 additions and 15 deletions.
31 changes: 24 additions & 7 deletions llama/generation.py
@@ -19,6 +19,13 @@
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

if torch.backends.mps.is_available():
device = torch.device('mps')
elif torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')

Role = Literal["system", "user", "assistant"]


@@ -59,14 +66,19 @@ def build(
model_parallel_size: Optional[int] = None,
) -> "Llama":
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if device == "cuda":
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

if device == "cuda":
torch.cuda.set_device(local_rank)

# seed must be the same in all processes
torch.manual_seed(1)
@@ -92,9 +104,13 @@ def build(
)
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
if device == "cuda":
torch.set_default_tensor_type(torch.cuda.HalfTensor)
else:
torch.set_default_tensor_type(torch.HalfTensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
model.to(device)
print(f"Loaded in {time.time() - start_time:.2f} seconds")

return Llama(model, tokenizer)
@@ -123,14 +139,15 @@ def generate(
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)

for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device=device)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
eos_reached = torch.tensor([False] * bsz, device=device)
input_text_mask = tokens != pad_id
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
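For readers adapting their own fork, the generation.py changes above reduce to roughly the following device selection and setup. This is a condensed sketch, not a verbatim copy of the diff; the helper names pick_device, init_distributed, and set_default_dtype are illustrative and do not appear in the commit.

import os
import torch
import torch.distributed


def pick_device() -> torch.device:
    # Prefer Apple's Metal backend, then CUDA, then plain CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def init_distributed(device: torch.device) -> None:
    # NCCL requires CUDA; gloo covers CPU and MPS single-node runs.
    if not torch.distributed.is_initialized():
        backend = "nccl" if device.type == "cuda" else "gloo"
        torch.distributed.init_process_group(backend)
    if device.type == "cuda":
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))


def set_default_dtype(device: torch.device) -> None:
    # torch.cuda.HalfTensor only exists for CUDA; elsewhere use the CPU half
    # tensor type and move the built model with model.to(device), as the diff does.
    if device.type == "cuda":
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
    else:
        torch.set_default_tensor_type(torch.HalfTensor)

In the commit itself this logic is inlined in Llama.build rather than factored into helpers. Note also that the sketch compares device.type against "cuda"; comparing a torch.device object directly to a string may not behave the same on every PyTorch version.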
27 changes: 20 additions & 7 deletions llama/model.py
@@ -15,6 +15,14 @@
)
from torch import nn

if torch.backends.mps.is_available():
device = torch.device('mps')
elif torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')



@dataclass
class ModelArgs:
@@ -47,6 +55,8 @@ def forward(self, x):

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
freqs = freqs.to(device)

t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
@@ -66,12 +76,15 @@ def apply_rotary_emb(
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq = xq.to('cpu')
xk = xk.to('cpu')
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
freqs_cis = freqs_cis.to('cpu')
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
return xq_out.type_as(xq).to(device), xk_out.type_as(xk).to(device)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -132,15 +145,15 @@ def __init__(self, args: ModelArgs):
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
).to(device)
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
).to(device)

def forward(
self,
@@ -251,7 +264,7 @@ def __init__(self, params: ModelArgs):
self.n_layers = params.n_layers

self.tok_embeddings = ParallelEmbedding(
params.vocab_size, params.dim, init_method=lambda x: x
params.vocab_size, params.dim, init_method=lambda x: x,
)

self.layers = torch.nn.ModuleList()
@@ -271,18 +284,18 @@ def __init__(self, params: ModelArgs):
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
#self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

mask = None
if seqlen > 1:
mask = torch.full(
(1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
(1, 1, seqlen, seqlen), float("-inf"), device=torch.device('cpu')
)
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = layer(h, start_pos, freqs_cis, (mask.to(device) if mask is not None else mask))
h = self.norm(h)
output = self.output(h).float()
return output
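The model.py changes route the complex-valued rotary-embedding math through the CPU, since ops such as torch.polar and torch.view_as_complex were not implemented for the MPS backend at the time, and then move the results back to the selected device. Below is a condensed sketch of that round-trip under those assumptions; the broadcast reshape is inlined here for brevity, whereas the commit keeps the original reshape_for_broadcast helper.

import torch


def apply_rotary_emb_on_cpu(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    device: torch.device,
):
    # Move the query/key projections to CPU for the complex-valued math.
    xq, xk = xq.to("cpu"), xk.to("cpu")
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast freqs_cis over batch and heads:
    # (seqlen, head_dim // 2) -> (1, seqlen, 1, head_dim // 2).
    freqs = freqs_cis.to("cpu").view(1, xq_.shape[1], 1, xq_.shape[-1])
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    # Move the rotated results back to the target device (e.g. MPS).
    return xq_out.type_as(xq).to(device), xk_out.type_as(xk).to(device)

The commit achieves the same effect by moving xq, xk, and freqs_cis to CPU inside apply_rotary_emb and calling .to(device) on the outputs.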
3 changes: 2 additions & 1 deletion llama/tokenizer.py
@@ -38,4 +38,5 @@ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
return t

def decode(self, t: List[int]) -> str:
return self.sp_model.decode(t)
#return self.sp_model.decode(t)
return self.sp_model.decode(list(filter(lambda tk: tk != -1, t)))
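The decode change guards against pad tokens reaching SentencePiece: batched generation fills unused positions with tokenizer.pad_id (-1 for the Llama tokenizer), and sp_model.decode raises on out-of-range ids, so they are filtered out first. A small illustration under assumed setup (the tokenizer path and the simulated padding are not part of the commit):

from llama.tokenizer import Tokenizer

tok = Tokenizer(model_path="./tokenizer.model")  # path is illustrative
ids = tok.encode("Hello from MPS", bos=True, eos=False)
ids += [-1, -1]  # simulate pad ids left over from batched generation
print(tok.decode(ids))  # -1 entries are dropped instead of raising an error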

2 comments on commit 9a5670b


@mauermbq mauermbq commented on 9a5670b Aug 25, 2023


Thanks for this example. After following your changes, on my Intel Mac this raised another exception:

File "/opt/dev/llms/codellama/llama/model.py", line 279, in init
File "/opt/dev/llms/codellama/llama/model.py", line 279, in init
self.freqs_cis = precompute_freqs_cis(
self.freqs_cis = precompute_freqs_cis(
File "/opt/dev/llms/codellama/llama/model.py", line 62, in precompute_freqs_cis
File "/opt/dev/llms/codellama/llama/model.py", line 62, in precompute_freqs_cis
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
NotImplementedError: The operator 'aten::polar.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on pytorch/pytorch#77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
NotImplementedError: The operator 'aten::polar.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on pytorch/pytorch#77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 53666) of binary: /opt/dev/miniconda3/envs/llama/bin/python3.10

and after setting the env variable:
RuntimeError: ProcessGroupGloo::allgather: invalid tensor type at index 0 (expected TensorOptions(dtype=c10::Half, device=cpu, layout=Strided, requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)), got TensorOptions(dtype=c10::Half, device=mps:0, layout=Strided, requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))

ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 55764) of binary: /opt/dev/miniconda3/envs/llama/bin/python3.10
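For reference, the fallback suggested by the first traceback has to be in the environment before torch is initialized. One minimal way to set it from Python (an illustration, not something done in this thread):

import os

# Must be set before `import torch`, or the MPS backend will not pick it up.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # ops missing on MPS now fall back to CPU (slower, but they run)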


@Skeeve4711 Skeeve4711 commented on 9a5670b Aug 25, 2023


Maybe you need to install the latest PyTorch from source. They apparently fixed the "aten::polar" stuff just a few hours ago, which is crazy fitting for us.
After that I didn't need to set the variable anymore.
Here's the repo: https://github.com/pytorch/pytorch

Here are the instructions as far as I remember: :D

1. git clone https://github.com/pytorch/pytorch.git
2. cd pytorch
3. conda install pkg-config libuv
4. ccmake build -> Make sure USE_DISTRIBUTED is set to ON -> Press c to set new configuration -> press q
5. pip install .

After that it should work.
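After rebuilding, a quick way to check whether aten::polar now runs natively on MPS is a small probe like the one below (an assumption for verification purposes, not part of the instructions above):

import torch

def polar_works_on_mps() -> bool:
    if not torch.backends.mps.is_available():
        return False
    try:
        ones = torch.ones(4, device="mps")
        torch.polar(ones, ones)  # raises NotImplementedError on older builds
        return True
    except (NotImplementedError, RuntimeError):
        return False

print("aten::polar on MPS:", polar_works_on_mps())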
