Inference: general tensor-parallel examples #144

Merged: 6 commits, Feb 22, 2022
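The new scripts all follow the same pattern: build a Hugging Face model or pipeline, wrap the model with deepspeed.init_inference(), and pass an injection_policy that maps a transformer layer class to the names of the output linear layers DeepSpeed should shard across GPUs. Below is a consolidated, commented sketch of that pattern; it reuses the T5 model and layer names from test-t5.py later in this diff and is illustrative rather than an additional script in the PR.

import os
import torch
import deepspeed
from transformers import pipeline
from transformers.models.t5.modeling_t5 import T5Block

# The DeepSpeed/torch.distributed launcher sets LOCAL_RANK and WORLD_SIZE for each process.
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

# 1. Build the Hugging Face pipeline as usual.
pipe = pipeline("text2text-generation", model="google/t5-v1_1-small", device=local_rank)

# 2. Wrap the model. injection_policy maps a transformer layer class to the names of
#    the linear layers (attention output and transformer output) that DeepSpeed shards
#    across the mp_size GPUs.
pipe.model = deepspeed.init_inference(
    pipe.model,
    mp_size=world_size,
    dtype=torch.float,
    injection_policy={T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')})
pipe.device = torch.device(f'cuda:{local_rank}')

# 3. Run inference; print from a single rank only to avoid duplicated output.
output = pipe("Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy")
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(output)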
3 changes: 1 addition & 2 deletions inference/huggingface/gpt-neo.py
@@ -15,7 +15,6 @@
"***************** Creating model in RANK ({0}) with WORLD_SIZE = {1} *****************"
.format(local_rank,
world_size))

generator = pipeline('text-generation',
model='EleutherAI/gpt-neo-2.7B',
device=local_rank)
@@ -25,4 +24,4 @@
replace_method='auto',
replace_with_kernel_inject=True)
string = generator("DeepSpeed is", do_sample=True, min_length=50)
print(string)
print(string)
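Note the contrast with the scripts added below: gpt-neo.py relies on DeepSpeed's built-in kernel injection (replace_method='auto' with replace_with_kernel_inject=True), whereas the new examples pass an explicit injection_policy so tensor parallelism can also be applied to models for which DeepSpeed ships no built-in injection policy or fused kernel.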
28 changes: 28 additions & 0 deletions inference/huggingface/test-electra.py
@@ -0,0 +1,28 @@
from transformers import pipeline
import transformers
import deepspeed
import torch
import os
from transformers.models.electra.modeling_electra import ElectraLayer

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '4'))

pipe = pipeline('fill-mask', model="google/electra-base-generator",
tokenizer="google/electra-base-generator")

# The injection_policy shows two things:
# 1. which layer module we need to add Tensor-Parallelism to
# 2. the name of one or several linear layers: a) the attention output (for both encoder and decoder),
#    and b) the transformer output
pipe.model = deepspeed.init_inference(
pipe.model,
mp_size=world_size,
dtype=torch.float,
injection_policy={ElectraLayer: ('output.dense', )}  # trailing comma: a one-element tuple of layer names
)
pipe.device = torch.device(f'cuda:{local_rank}')
output = pipe(f"HuggingFace is creating a {pipe.tokenizer.mask_token} that the community uses to solve NLP tasks.")

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
print(output)
36 changes: 36 additions & 0 deletions inference/huggingface/test-gptj.py
@@ -0,0 +1,36 @@
import os
import torch
import deepspeed
import transformers

from deepspeed import module_inject
from transformers import pipeline
from transformers.models.gptj.modeling_gptj import GPTJBlock

# Get local gpu rank from torch.distributed/deepspeed launcher
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

print(
"***************** Creating model in RANK ({0}) with WORLD_SIZE = {1} *****************"
.format(local_rank,
world_size))
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

inp_tokens = tokenizer("DeepSpeed is", return_tensors="pt",)
model = deepspeed.init_inference(model,
mp_size=world_size,
dtype=torch.float,
injection_policy={GPTJBlock: ('attn.out_proj','mlp.fc_out')},
replace_with_kernel_inject=False)

for token in inp_tokens:
if torch.is_tensor(inp_tokens[token]):
inp_tokens[token] = inp_tokens[token].to(f'cuda:{local_rank}')

model.cuda().to(f'cuda:{local_rank}')
string = tokenizer.batch_decode(model.generate(**inp_tokens,min_length=50,))[0]
print(string)
29 changes: 29 additions & 0 deletions inference/huggingface/test-roberta.py
@@ -0,0 +1,29 @@
from transformers import pipeline
import transformers
import deepspeed
import torch
import os
from transformers.models.roberta.modeling_roberta import RobertaLayer

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '4'))

pipe = pipeline('fill-mask', model="roberta-large", device=local_rank)

# The injection_policy shows two things:
# 1. which layer module we need to add Tensor-Parallelism to
# 2. the name of several linear layers: a) the attention output (for both encoder and decoder),
#    and b) the transformer output

pipe.model = deepspeed.init_inference(
pipe.model,
mp_size=world_size,
dtype=torch.float,
injection_policy={RobertaLayer: ('output.dense', )}  # trailing comma: a one-element tuple of layer names
)

pipe.device = torch.device(f'cuda:{local_rank}')
output = pipe("Hello I'm a <mask> model.")

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
print(output)
29 changes: 29 additions & 0 deletions inference/huggingface/test-t5.py
@@ -0,0 +1,29 @@
from transformers import pipeline
import transformers
import deepspeed
import torch
import os
from transformers.models.t5.modeling_t5 import T5Block

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '4'))

pipe = pipeline("text2text-generation", model="google/t5-v1_1-small", device=local_rank)

# The injection_policy shows two things:
# 1. which layer module we need to add Tensor-Parallelism to
# 2. the name of several linear layers: a) the attention output (for both encoder and decoder),
#    and b) the transformer output

pipe.model = deepspeed.init_inference(
pipe.model,
mp_size=world_size,
dtype=torch.float,
injection_policy={T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')}
)

pipe.device = torch.device(f'cuda:{local_rank}')
output = pipe("Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy")

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
print(output)
51 changes: 51 additions & 0 deletions inference/huggingface/test-wav2vec2.py
@@ -0,0 +1,51 @@
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import soundfile as sf
import torch
from jiwer import wer
import os
import deepspeed
from deepspeed import module_inject
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2EncoderLayer

librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

# Get local gpu rank from torch.distributed/deepspeed launcher
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

print(
"***************** Creating model in RANK ({0}) with WORLD_SIZE = {1} *****************"
.format(local_rank,
world_size))

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

model = deepspeed.init_inference(model,
mp_size=world_size,
dtype=torch.float,
injection_policy={Wav2Vec2EncoderLayer: ('attention.out_proj','feed_forward.output_dense')},
replace_with_kernel_inject=False)
model.to(f'cuda:{local_rank}')
def map_to_array(batch):
speech, _ = sf.read(batch["file"])
batch["speech"] = speech
return batch

librispeech_eval = librispeech_eval.map(map_to_array)

def map_to_pred(batch):
input_values = processor(batch["speech"], return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to(f'cuda:{local_rank}')).logits

predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch

result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])

print("WER:", wer(result["text"], result["transcription"]))