I was experimenting with loading the `qwen2` model with world-size 2. I am loading the workers entirely on CPU. The following is the code I was testing:
```python
import os

import torch
from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer.pad_token = tokenizer.eos_token


def load_model(rank: int):
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    # Example micro-batch used to trace the model for splitting.
    mb_inputs = tokenizer(("How do you",), return_tensors="pt", padding=True).to(
        torch.device("cpu")
    )
    # Split the model into two stages at model.layers.12.
    pipe = pipeline(
        model,
        mb_args=(mb_inputs["input_ids"],),
        split_spec={"model.layers.12": SplitPoint.BEGINNING},
    )
    stage = pipe.build_stage(rank, device=torch.device("cpu"))
    return stage


full_batch_prompts = ("How do you",)
inputs = tokenizer(full_batch_prompts, return_tensors="pt", padding=False).to(
    torch.device("cpu")
)

rank = int(os.getenv("RANK"))
torch.distributed.init_process_group("gloo", rank=rank, world_size=2)

stage = load_model(rank)
print("loaded model, now initiating pipeline")

schedule = ScheduleGPipe(stage, 1)
# Only the first stage feeds the input ids into the schedule.
if rank == 0:
    args = inputs["input_ids"]
    print(args.shape, args)
else:
    args = None

output = schedule.step(args)
print(f"{rank} - op - {output}")
# If this rank produced output, take the logits and greedy-decode the next token.
if output is not None:
    next_token_logits = output[0][:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)
    print(tokenizer.batch_decode(next_token))
```
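For completeness, this is roughly how I launch the two ranks (a minimal launcher sketch; `repro.py` is a placeholder name for the script above). Since `init_process_group("gloo", ...)` is called without an `init_method`, it falls back to the default `env://` rendezvous, so `MASTER_ADDR` and `MASTER_PORT` must be set for both processes in addition to `RANK`:

```python
# Minimal launcher sketch; repro.py is a placeholder for the script above.
import os
import subprocess

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# Start one process per rank; each reads RANK from its environment.
procs = [
    subprocess.Popen(["python", "repro.py"], env={**os.environ, "RANK": str(r)})
    for r in range(2)
]
for p in procs:
    p.wait()
```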
The code hangs indefinitely. However, I did get output from rank 0.