Inference: general tensor-parallel examples (#144)
* add T5 example using tensor-parallelism

* refine t5 test

* add more tests to try the Tensor-Parallel inference

* remove pdb

* add tests for GPT-J and wav2vec2 model architectures

Co-authored-by: Ammar Ahmad Awan <[email protected]>
RezaYazdaniAminabadi and awan-10 authored Feb 22, 2022
1 parent 068e656 commit 9c48e36
Showing 6 changed files with 174 additions and 2 deletions.
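All of the new scripts follow the same pattern: read LOCAL_RANK and WORLD_SIZE from the environment (set by the DeepSpeed launcher, typically invoked as deepspeed --num_gpus <N> <script>.py), build a stock Hugging Face model or pipeline, and wrap it with deepspeed.init_inference, passing an injection_policy that maps a transformer layer class to the names of the linear layers whose outputs must be all-reduced across ranks. This generic tensor-parallel path is used instead of kernel injection (the GPT-J and wav2vec2 scripts pass replace_with_kernel_inject=False explicitly). The following sketch condenses that shared flow using the T5 example's settings:

import os
import torch
import deepspeed
from transformers import pipeline
from transformers.models.t5.modeling_t5 import T5Block

# Ranks are provided by the DeepSpeed launcher.
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

pipe = pipeline("text2text-generation", model="google/t5-v1_1-small", device=local_rank)

# Map the layer class to the linear layers that need an all-reduce on their output.
pipe.model = deepspeed.init_inference(
    pipe.model,
    mp_size=world_size,
    dtype=torch.float,
    injection_policy={T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')})
pipe.device = torch.device(f'cuda:{local_rank}')

output = pipe("Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy")
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(output)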
3 changes: 1 addition & 2 deletions inference/huggingface/gpt-neo.py
@@ -15,7 +15,6 @@
    "***************** Creating model in RANK ({0}) with WORLD_SIZE = {1} *****************"
    .format(local_rank,
            world_size))

generator = pipeline('text-generation',
                     model='EleutherAI/gpt-neo-2.7B',
                     device=local_rank)
@@ -25,4 +24,4 @@
                                           replace_method='auto',
                                           replace_with_kernel_inject=True)
string = generator("DeepSpeed is", do_sample=True, min_length=50)
print(string)
print(string)
28 changes: 28 additions & 0 deletions inference/huggingface/test-electra.py
@@ -0,0 +1,28 @@
from transformers import pipeline
import transformers
import deepspeed
import torch
import os
from transformers.models.electra.modeling_electra import ElectraLayer

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '4'))

pipe = pipeline('fill-mask', model="google/electra-base-generator",
                tokenizer="google/electra-base-generator")

# The injection_policy shows two things:
# 1. which layer module we need to add Tensor-Parallelism to
# 2. the name of one or several linear layers: a) attention_output (both encoder and decoder),
#    and b) transformer output
pipe.model = deepspeed.init_inference(
    pipe.model,
    mp_size=world_size,
    dtype=torch.float,
    injection_policy={ElectraLayer: ('output.dense', )}
)
pipe.device = torch.device(f'cuda:{local_rank}')
output = pipe(f"HuggingFace is creating a {pipe.tokenizer.mask_token} that the community uses to solve NLP tasks.")

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(output)
36 changes: 36 additions & 0 deletions inference/huggingface/test-gptj.py
@@ -0,0 +1,36 @@
import os
import torch
import deepspeed
import transformers

from deepspeed import module_inject
from transformers import pipeline
from transformers.models.gptj.modeling_gptj import GPTJBlock

# Get local gpu rank from torch.distributed/deepspeed launcher
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

print(
    "***************** Creating model in RANK ({0}) with WORLD_SIZE = {1} *****************"
    .format(local_rank,
            world_size))
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

inp_tokens = tokenizer("DeepSpeed is", return_tensors="pt",)
model = deepspeed.init_inference(model,
                                 mp_size=world_size,
                                 dtype=torch.float,
                                 injection_policy={GPTJBlock: ('attn.out_proj','mlp.fc_out')},
                                 replace_with_kernel_inject=False)

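# Move the tokenized inputs to this rank's GPU before generation.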
for token in inp_tokens:
    if torch.is_tensor(inp_tokens[token]):
        inp_tokens[token] = inp_tokens[token].to(f'cuda:{local_rank}')

model.cuda().to(f'cuda:{local_rank}')
string = tokenizer.batch_decode(model.generate(**inp_tokens,min_length=50,))[0]
print(string)
29 changes: 29 additions & 0 deletions inference/huggingface/test-roberta.py
@@ -0,0 +1,29 @@
from transformers import pipeline
import transformers
import deepspeed
import torch
import os
from transformers.models.roberta.modeling_roberta import RobertaLayer

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '4'))

pipe = pipeline('fill-mask', model="roberta-large", device=local_rank)

# The injection_policy shows two things:
# 1. which layer module we need to add Tensor-Parallelism to
# 2. the name of several linear layers: a) attention_output (both encoder and decoder),
#    and b) transformer output

pipe.model = deepspeed.init_inference(
    pipe.model,
    mp_size=world_size,
    dtype=torch.float,
    injection_policy={RobertaLayer: ('output.dense', )}
)

pipe.device = torch.device(f'cuda:{local_rank}')
output = pipe("Hello I'm a <mask> model.")

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(output)
29 changes: 29 additions & 0 deletions inference/huggingface/test-t5.py
@@ -0,0 +1,29 @@
from transformers import pipeline
import transformers
import deepspeed
import torch
import os
from transformers.models.t5.modeling_t5 import T5Block

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '4'))

pipe = pipeline("text2text-generation", model="google/t5-v1_1-small", device=local_rank)

# The injection_policy shows two things:
# 1. which layer module we need to add Tensor-Parallelism to
# 2. the name of several linear layers: a) attention_output (both encoder and decoder),
#    and b) transformer output

pipe.model = deepspeed.init_inference(
    pipe.model,
    mp_size=world_size,
    dtype=torch.float,
    injection_policy={T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')}
)

pipe.device = torch.device(f'cuda:{local_rank}')
output = pipe("Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy")

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(output)
51 changes: 51 additions & 0 deletions inference/huggingface/test-wav2vec2.py
@@ -0,0 +1,51 @@
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import soundfile as sf
import torch
from jiwer import wer
import os
import deepspeed
from deepspeed import module_inject
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2EncoderLayer

librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

# Get local gpu rank from torch.distributed/deepspeed launcher
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

print(
    "***************** Creating model in RANK ({0}) with WORLD_SIZE = {1} *****************"
    .format(local_rank,
            world_size))

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

model = deepspeed.init_inference(model,
                                 mp_size=world_size,
                                 dtype=torch.float,
                                 injection_policy={Wav2Vec2EncoderLayer: ('attention.out_proj','feed_forward.output_dense')},
                                 replace_with_kernel_inject=False)
model.to(f'cuda:{local_rank}')
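
# Read the raw audio waveform for each dataset example from disk.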
def map_to_array(batch):
    speech, _ = sf.read(batch["file"])
    batch["speech"] = speech
    return batch

librispeech_eval = librispeech_eval.map(map_to_array)

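# Run the model on each utterance and decode the argmax predictions to text.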
def map_to_pred(batch):
    input_values = processor(batch["speech"], return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values.to(f'cuda:{local_rank}')).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch

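# Transcribe the test split one example at a time, then report word error rate.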
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])

print("WER:", wer(result["text"], result["transcription"]))
