Add FALCON Auto-TP Support (#3640)
* Add FALCON auto-tp support
* Added a (skipped) unit test and refactored the code to be more readable

---------

Co-authored-by: Michael Wyatt <[email protected]>
RezaYazdaniAminabadi and mrwyattii authored Jul 5, 2023
1 parent 385e89d commit f3c93b0
Showing 3 changed files with 63 additions and 33 deletions.
4 changes: 4 additions & 0 deletions deepspeed/module_inject/auto_tp.py
@@ -108,6 +108,10 @@ def tp_parser(model):
                     gem_list = gem_list + [layer]
                 elif 'down_proj' in layer:
                     gem_list = gem_list + [layer]
+                elif 'self_attention.dense' in layer and 'falcon' in str(
+                        type(module)):  # this is a hack to get the right linear layer for this model!
+                    gem_list = gem_list + [layer]
+
             layer_list = []
             if gem_list != []:
                 gem_list = list(set(gem_list))
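For context, the new branch is a plain substring test: when a layer name contains 'self_attention.dense' and the owning module's type string mentions 'falcon', the layer joins gem_list, the set of output projections whose results must be all-reduced under tensor parallelism. A minimal standalone sketch of that check follows; the layer name and type string are illustrative assumptions, not values taken from the diff:

# Illustrative sketch of the substring check added to tp_parser (assumed names).
layer = 'transformer.h.0.self_attention.dense'                          # assumed layer name
module_type = "transformers_modules.falcon-7b.modelling_RW.Attention"   # assumed stand-in for str(type(module))

gem_list = []
if 'self_attention.dense' in layer and 'falcon' in module_type:
    gem_list = gem_list + [layer]   # marked as needing an all-reduce, same as in the diff

print(gem_list)  # ['transformer.h.0.self_attention.dense']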
50 changes: 18 additions & 32 deletions deepspeed/module_inject/replace_module.py
@@ -426,38 +426,14 @@ def _slice_embedding(child, name, conv_linear_layer):
     def update_mp_params(child):
         if getattr(child, "replaced", False) == True:
             return
-        if hasattr(child, 'n_heads'):
-            assert child.n_heads % mp_size == 0, "n_heads ({}) must be divisible by mp_size ({})".format(
-                child.n_heads, mp_size)
-            child.n_heads = child.n_heads // mp_size
-        if hasattr(child, 'inner_dim'):
-            assert child.inner_dim % mp_size == 0, "inner_dim ({}) must be divisible by mp_size ({})".format(
-                child.inner_dim, mp_size)
-            child.inner_dim = child.inner_dim // mp_size
-        if hasattr(child, 'num_heads'):
-            assert child.num_heads % mp_size == 0, "num_heads ({}) must be divisible by mp_size ({})".format(
-                child.num_heads, mp_size)
-            child.num_heads = child.num_heads // mp_size
-        if hasattr(child, 'num_attention_heads'):
-            assert child.num_attention_heads % mp_size == 0, "num_attention_heads ({}) must be divisible by mp_size ({})".format(
-                child.num_attention_heads, mp_size)
-            child.num_attention_heads = child.num_attention_heads // mp_size
-        if hasattr(child, 'num_attn_heads'):
-            assert child.num_attn_heads % mp_size == 0, "num_attn_heads ({}) must be divisible by mp_size ({})".format(
-                child.num_attn_heads, mp_size)
-            child.num_attn_heads = child.num_attn_heads // mp_size
-        if hasattr(child, 'all_head_size'):
-            assert child.all_head_size % mp_size == 0, "all_head_size ({}) must be divisible by mp_size ({})".format(
-                child.all_head_size, mp_size)
-            child.all_head_size = child.all_head_size // mp_size
-        if hasattr(child, 'embed_dim'):
-            assert child.embed_dim % mp_size == 0, "embed_dim must ({}) be divisible by mp_size ({})".format(
-                child.embed_dim, mp_size)
-            child.embed_dim = child.embed_dim // mp_size
-        if hasattr(child, 'hidden_size'):
-            assert child.hidden_size % mp_size == 0, "hidden_size ({}) must be divisible by mp_size ({})".format(
-                child.hidden_size, mp_size)
-            child.hidden_size = child.hidden_size // mp_size
+        for param in [
+                "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
+                "all_head_size", "embed_dim", "hidden_size"
+        ]:
+            if hasattr(child, param):
+                param_val = getattr(child, param)
+                assert param_val % mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({mp_size})"
+                setattr(child, param, param_val // mp_size)
         setattr(child, "replaced", True)

     conv_linear_layer = False
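The refactor replaces the copy-pasted hasattr/assert blocks with a single loop over candidate attribute names, and adds num_kv to that list. A self-contained sketch of the same pattern applied to a dummy config; DummyConfig and the mp_size value are assumptions for illustration only:

# Standalone sketch of the refactored update_mp_params loop (illustration only).
class DummyConfig:
    num_attention_heads = 32   # assumed values
    hidden_size = 4096

child = DummyConfig()
mp_size = 4

for param in [
        "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
        "all_head_size", "embed_dim", "hidden_size"
]:
    if hasattr(child, param):
        param_val = getattr(child, param)
        assert param_val % mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({mp_size})"
        setattr(child, param, param_val // mp_size)   # attributes the config lacks are simply skipped

print(child.num_attention_heads, child.hidden_size)  # 8 1024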
@@ -495,6 +471,16 @@ def _replace_module(r_module, prev_name='', prev_class_name=''):
             if child.__class__ in linear_policies:
                 setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                          conv_linear_layer))
+            elif any(isinstance(child, lp) for lp in linear_policies):
+                # Added for falcon model support
+                # Note: isinstance will account for class inheritance, child.__class__ does not
+                key = None
+                for lp in linear_policies:
+                    if isinstance(child, lp):
+                        key = lp
+                        break
+                assert key is not None
+                setattr(r_module, name, linear_policies[key](child, prev_name + '.' + name, conv_linear_layer))
             else:
                 update_mp_params(child)
                 _replace_module(child, name, class_name)
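As the in-code note says, the dict lookup on child.__class__ only matches exact classes, while the new isinstance fallback also matches subclasses; this matters when a model's remote code derives its own layer class from nn.Linear rather than using it directly. A self-contained sketch of the difference; FalconLinear is a hypothetical stand-in, not a class from the diff:

# Sketch of exact-class lookup vs isinstance matching (FalconLinear is hypothetical).
import torch.nn as nn

class FalconLinear(nn.Linear):   # stand-in for a remote-code Linear subclass
    pass

linear_policies = {nn.Linear: lambda child, name, conv_linear_layer: f"replace {name}"}
child = FalconLinear(8, 8)

print(child.__class__ in linear_policies)   # False: the subclass itself is not a dict key

key = None
for lp in linear_policies:
    if isinstance(child, lp):                # True: isinstance follows the inheritance chain
        key = lp
        break
print(linear_policies[key](child, 'transformer.h.0.dense', False))  # replace transformer.h.0.dense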
42 changes: 41 additions & 1 deletion tests/unit/inference/test_inference.py
@@ -13,7 +13,7 @@
 from unit.common import DistributedTest
 from packaging import version as pkg_version
 from deepspeed.ops.op_builder import OpBuilder
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer
 from transformers.models.t5.modeling_t5 import T5Block
 from transformers.models.roberta.modeling_roberta import RobertaLayer
 from huggingface_hub import HfApi
@@ -380,6 +380,46 @@ def test(
         assert assert_fn(bs_output, ds_output)


+@pytest.mark.seq_inference
+@pytest.mark.parametrize("model_w_task", [("tiiuae/falcon-7b", "text-generation")], ids=["falcon"])
+class TestAutoTP(DistributedTest):
+    world_size = 1
+
+    def test(
+        self,
+        model_w_task,
+        query,
+        inf_kwargs,
+        assert_fn,
+    ):
+        # TODO: enable this test for H100 tests
+        pytest.skip("Not enough GPU memory for this on V100 runners")
+        model, task = model_w_task
+        dtype = torch.bfloat16
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+
+        # We have to load these large models on CPU with pipeline because not
+        # enough GPU memory
+        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
+        pipe = pipeline(task,
+                        model=model,
+                        tokenizer=tokenizer,
+                        torch_dtype=dtype,
+                        trust_remote_code=True,
+                        device=torch.device("cpu"),
+                        framework="pt")
+        #bs_output = pipe(query, **inf_kwargs)
+
+        pipe.model = deepspeed.init_inference(pipe.model, mp_size=self.world_size, replace_with_kernel_inject=False)
+        # Switch device to GPU so that input tensors are not on CPU
+        pipe.device = torch.device(get_accelerator().device_name(local_rank))
+        ds_output = pipe(query, **inf_kwargs)
+
+        #print(local_rank, "baseline", bs_output)
+        print(local_rank, "deepspeed", ds_output)
+        #assert assert_fn(bs_output, ds_output)
+
+
 @pytest.mark.seq_inference
 @pytest.mark.parametrize(
     "model_w_task, injection_policy",
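Outside the test harness, the same auto-TP path can be exercised with an ordinary inference script; keeping replace_with_kernel_inject=False is what routes the model through auto tensor parallelism instead of kernel injection. A hedged sketch under assumed settings (model name, dtype, prompt, and launch command are illustrative):

# Sketch of Falcon inference with auto-TP (assumed settings), launched e.g. with:
#   deepspeed --num_gpus 2 run_falcon.py
import os
import torch
import deepspeed
from transformers import pipeline, AutoTokenizer

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

model = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
pipe = pipeline("text-generation",
                model=model,
                tokenizer=tokenizer,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device=torch.device("cpu"),   # load on CPU first; GPU memory is the constraint
                framework="pt")

# replace_with_kernel_inject=False routes the model through auto-TP rather than kernel injection
pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, replace_with_kernel_inject=False)
pipe.device = torch.device(f"cuda:{local_rank}")  # move pipeline inputs to this rank's GPU

print(local_rank, pipe("DeepSpeed is", max_new_tokens=32))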
