[Deepspeed Inference] HF Integration #14426

Draft: wants to merge 10 commits into base: main
48 changes: 48 additions & 0 deletions src/transformers/deepspeed.py
@@ -33,6 +33,54 @@
logger = logging.get_logger(__name__)


inference_custom_map = dict(
    electra=dict(ElectraLayer=("output.dense")),
Contributor:

@stas00 This will only parallelize the output.dense layer, and the other parts will be duplicated on all GPUs, resulting in memory inefficiency.

Contributor:

To parallelize all parts, all of the layer information must be provided. This would be similar to the existing DeepSpeed Inference policy, and not very different from the policy I used in Parallelformers.

Contributor:

@RezaYazdaniAminabadi Am I right? Or do you have other opinions?

Contributor Author:

As the PR says, this is very early. Basically, all I did was convert an example that Reza gave me and integrate it into the HF Trainer. I'm treating it as a black box for now and waiting for Reza to complete the project before trying to understand how it works.

But I trust Reza will be happy to answer your question.

Contributor:

@hyunwoongko, this only shows which linear layers require an all_reduce, so it is not going to use the same policy as when injecting the kernels. You can find more detail on how the other layers are partitioned in the replace_module function in DeepSpeed. Basically, the policy here just shows which parts need to be partitioned horizontally, whereas the rest are partitioned vertically. Does that make sense?

Contributor Author:

Thank you for the explanatory notes, @RezaYazdaniAminabadi - I have added them to the file, so this is covered.

    roberta=dict(RobertaLayer=("output.dense")),
    t5=dict(T5Block=("SelfAttention.o", "EncDecAttention.o", "DenseReluDense.wo")),
    albert=dict(AlbertLayer=("attention.dense", "ffn_output")),
    bart=dict(BartEncoderLayer=("self_attn.out_proj", "fc2")),
    deberta=dict(DebertaLayer=("output.dense")),
    deberta_v2=dict(DebertaV2Layer=("output.dense")),
    wav2vec2=dict(Wav2Vec2EncoderLayer=("attention.out_proj", "feed_forward.output_dense")),
)

inference_auto_map = ["gpt_neo", "gptj", "gpt2", "bert"]


def deepspeed_inference_init(trainer):
    """
    XXX:
    """

    dep_version_check("deepspeed")
    import deepspeed

    args = trainer.args

    model_arch = trainer.model.config.model_type

    if model_arch in inference_auto_map:
        kwargs = dict(
            replace_method="auto",
            replace_with_kernel_inject=True,
        )
    elif model_arch in inference_custom_map:
        kwargs = dict(injection_policy=inference_custom_map[model_arch])
    else:
        raise ValueError(
            f"[Deepspeed Inference] {model_arch} hasn't yet been mapped out, please file an Issue to request support for it"
        )

    deepspeed_inference_engine = deepspeed.init_inference(
        trainer.model,
        mp_size=args.world_size,
        dtype=torch.half if args.fp16 else torch.float,  # XXX: add bf16 once ds supports it
        **kwargs,
    )

    return deepspeed_inference_engine


def is_deepspeed_available():
    return importlib.util.find_spec("deepspeed") is not None

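To make the injection-policy discussion in the review thread above concrete, here is a minimal, standalone sketch (not part of the diff) of the underlying `deepspeed.init_inference` call for a model covered by `inference_custom_map`, following the public DeepSpeed Inference tutorial. The `t5-small` checkpoint, the 2-GPU `mp_size`, and the script name in the launch command are illustrative assumptions.

```python
# sketch_t5_inference.py -- illustrative; launch with: deepspeed --num_gpus 2 sketch_t5_inference.py
import deepspeed
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.models.t5.modeling_t5 import T5Block

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# The policy only names the linear layers whose output needs an all_reduce (partitioned
# horizontally); the remaining linears in each T5Block are partitioned vertically by
# DeepSpeed's replace_module machinery, as explained in the review thread above.
engine = deepspeed.init_inference(
    model,
    mp_size=2,  # tensor-parallel degree == number of GPUs the script is launched on
    dtype=torch.half,
    injection_policy={T5Block: ("SelfAttention.o", "EncDecAttention.o", "DenseReluDense.wo")},
)
model = engine.module  # the now-parallelized model, usable for generation/evaluation

inputs = tokenizer("translate English to French: Hello", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs)[0], skip_special_tokens=True))
```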
18 changes: 17 additions & 1 deletion src/transformers/trainer.py
@@ -61,7 +61,7 @@
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_inference_init, deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .file_utils import (
    CONFIG_NAME,
@@ -362,6 +362,7 @@ def __init__(
        if (
            self.is_model_parallel
            or args.deepspeed
            or args.deepspeed_inference
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
        ):
@@ -1829,6 +1830,11 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = dict(device=self.args.device)
            # if self.args.deepspeed_inference:
            #     print(data.dtype)
            #     print(kwargs)
            #     return data.to("cuda:0")

            if self.deepspeed and data.dtype != torch.int64:
                # NLP models inputs are int64 and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
@@ -2274,6 +2280,12 @@ def evaluation_loop(
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        if self.args.deepspeed_inference:
            deepspeed_inference_engine = deepspeed_inference_init(self)
            self.model = deepspeed_inference_engine.module
            self.model_wrapped = deepspeed_inference_engine
            self.deepspeed = deepspeed_inference_engine

        model = self._wrap_model(self.model, training=False)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
@@ -2437,6 +2449,10 @@ def _pad_across_processes(self, tensor, pad_index=-100):
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
        """
        # XXX: hangs here with 2 gpus if we don't return
        # if self.args.deepspeed_inference:
        #     return tensor

        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
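For an end-to-end view of the Trainer wiring added above, here is a hypothetical evaluation-only sketch, assuming the draft lands as-is; the checkpoint, the toy eval dataset, and the script/launch names are illustrative and not part of the PR.

```python
# eval_sketch.py -- illustrative; launch with: deepspeed --num_gpus 2 eval_sketch.py
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

model_name = "roberta-base"  # roberta is covered by inference_custom_map above
model = AutoModelForSequenceClassification.from_pretrained(model_name)  # untuned head; numbers are meaningless
tokenizer = AutoTokenizer.from_pretrained(model_name)

texts = ["a tiny example", "another tiny example"]
enc = tokenizer(texts, truncation=True, padding=True)
eval_dataset = Dataset.from_dict({**enc, "label": [0, 1]})

args = TrainingArguments(
    output_dir="out",
    do_train=False,
    do_eval=True,
    fp16=True,                  # deepspeed_inference_init picks torch.half when fp16 is set
    deepspeed_inference=True,   # the new flag added by this PR
    per_device_eval_batch_size=8,
)

trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, tokenizer=tokenizer)
# evaluation_loop calls deepspeed_inference_init(self) and swaps the inference engine in
metrics = trainer.evaluate()
print(metrics)
```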
11 changes: 9 additions & 2 deletions src/transformers/training_args.py
@@ -313,6 +313,9 @@ class TrainingArguments:
            Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
            evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
            `ds_config.json`) or an already loaded json file as a `dict`
        deepspeed_inference (`bool`, *optional*):
            Enable [Deepspeed Inference](https://www.deepspeed.ai/tutorials/inference-tutorial). This is an
            experimental feature and its API may change in the future.
        label_smoothing_factor (`float`, *optional*, defaults to 0.0):
            The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
            labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
@@ -635,9 +638,13 @@ class TrainingArguments:
    deepspeed: Optional[str] = field(
        default=None,
        metadata={
-            "help": "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already loaded json file as a dict"
+            "help": "Enable DeepSpeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already loaded json file as a dict"
        },
    )
    deepspeed_inference: bool = field(
        default=False,
        metadata={"help": "Enable DeepSpeed Inference"},
    )
    label_smoothing_factor: float = field(
        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
    )
@@ -992,7 +999,7 @@ def _setup_devices(self) -> "torch.device":
            self.local_rank = sm_dist.get_local_rank()
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
-        elif self.deepspeed:
+        elif self.deepspeed or self.deepspeed_inference:
            # deepspeed inits torch.distributed internally
            from .deepspeed import is_deepspeed_available

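Finally, a short sketch of how the new flag would surface on the command line through `HfArgumentParser`, which is how the example scripts build their `TrainingArguments`; the argv list and launch command are illustrative and assume the `deepspeed_inference` field exists as drafted.

```python
# Hedged sketch: parsing the new flag the way the example scripts do, assuming the
# deepspeed_inference field exists as drafted. Because _setup_devices defers torch.distributed
# initialization to DeepSpeed when the flag is set, this is meant to run under the deepspeed
# launcher, e.g. `deepspeed --num_gpus 2 my_script.py --output_dir out --do_eval --deepspeed_inference`.
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
# boolean dataclass fields with a False default become store-true style flags
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--do_eval", "--deepspeed_inference"]
)
print(training_args.deepspeed_inference)  # True
```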