Kernel injection fixes #601

Merged · 5 commits · Dec 15, 2020
1 change: 0 additions & 1 deletion deepspeed/__init__.py
@@ -13,7 +13,6 @@
 from .runtime.config import DeepSpeedConfig
 from .runtime.activation_checkpointing import checkpointing
 from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
-from .module_inject.replace_module import replace_transformer_layer, revert_transformer_layer
 from .utils import log_dist
 
 from .pipe import PipelineModule
1 change: 0 additions & 1 deletion deepspeed/module_inject/replace_module.py
@@ -186,7 +186,6 @@ def _replace_module(model, policies):
             orig = repr(child)
             setattr(model, name, policies[child.__class__](child))
             new = getattr(model, name)
-            print(f'{orig} -> {new}')
         else:
             _replace_module(child, policies)

3 changes: 3 additions & 0 deletions deepspeed/ops/__init__.py (mode 100644 → 100755)
@@ -3,4 +3,7 @@
 from . import sparse_attention
 from . import transformer
 
+from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
+from .module_inject import replace_module
+
 from ..git_version_info import compatible_ops as __compatible_ops__
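
With these exports in place, client code can pull both the kernel config and the injection entry point straight from `deepspeed.ops`, e.g. `from deepspeed.ops import DeepSpeedTransformerConfig, replace_module`.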
216 changes: 216 additions & 0 deletions deepspeed/ops/module_inject.py
@@ -0,0 +1,216 @@
import copy
import torch
import deepspeed

from deepspeed.ops import DeepSpeedTransformerConfig


def _copy_child_transformer_state(new_module, orig_child, pre_layer_norm):
    # copy relevant state from original child -> new module
    qw = orig_child.attention.self.query.weight
    qb = orig_child.attention.self.query.bias
    kw = orig_child.attention.self.key.weight
    kb = orig_child.attention.self.key.bias
    vw = orig_child.attention.self.value.weight
    vb = orig_child.attention.self.value.bias

    qkvw = torch.cat((qw, kw, vw), 0)
    qkvb = torch.cat((qb, kb, vb), 0)
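    # qkvw/qkvb now hold Q, K, V stacked along dim 0, the fused layout that
    # DeepSpeedTransformerLayer's attn_qkvw/attn_qkvb expect;
    # torch.chunk(t, 3, dim=0) is the inverse used on revert.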

    #qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, axis=0)
    #qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, axis=0)

    new_module.attn_qkvw.data = qkvw
    new_module.attn_qkvb.data = qkvb
    new_module.attn_ow.data = orig_child.attention.output.dense.weight
    new_module.attn_ob.data = orig_child.attention.output.dense.bias
    if pre_layer_norm:
        attention_layernorm = orig_child.PostAttentionLayerNorm
    else:
        attention_layernorm = orig_child.attention.output.LayerNorm
    new_module.attn_nw.data = attention_layernorm.weight
    new_module.attn_nb.data = attention_layernorm.bias
    if pre_layer_norm:
        intermediate_ff = orig_child.intermediate.dense_act
    else:
        intermediate_ff = orig_child.intermediate.dense
    new_module.inter_w.data = intermediate_ff.weight
    new_module.inter_b.data = intermediate_ff.bias
    new_module.output_w.data = orig_child.output.dense.weight
    new_module.output_b.data = orig_child.output.dense.bias
    if pre_layer_norm:
        transformer_layernorm = orig_child.PreAttentionLayerNorm
    else:
        transformer_layernorm = orig_child.output.LayerNorm
    new_module.norm_w.data = transformer_layernorm.weight
    new_module.norm_b.data = transformer_layernorm.bias


def _replace_transformer_layer(orig_layer_impl, model, transformer_config):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer.
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        transformer_config (dict): deepspeed transformer layer config containing hidden size, attention heads, etc.
    Returns:
        Updated nn.module with replaced transformer layers
    """
    def replace_fn(child):
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
        _copy_child_transformer_state(new_module,
                                      child,
                                      transformer_config.pre_layer_norm)

        return new_module

    return _replace_module(model=model,
                           orig_class=orig_layer_impl,
                           replace_fn=replace_fn)


def replace_module(orig_module_impl, model, replacement_module_config):
    """ Replace client module.
    Arguments:
        orig_module_impl (torch.nn.Module): original module implementation to replace,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        replacement_module_config (dict): deepspeed replacement module config (e.g., DeepSpeedTransformerConfig).

    Returns:
        Updated nn.module with replaced modules
    """
    assert isinstance(replacement_module_config, DeepSpeedTransformerConfig), \
        'Only DeepSpeedTransformerConfig is currently supported as replacement config'

    return _replace_transformer_layer(orig_layer_impl=orig_module_impl,
                                      model=model,
                                      transformer_config=replacement_module_config)


def _revert_transformer_layer(orig_layer_impl, model, bert_config, transformer_config):
    """ Revert DeepSpeed's transformer layer back to the original bert-style transformer layer.
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        bert_config (dict): model config containing hidden size, attention heads, etc.
        transformer_config (dict): deepspeed transformer config used for replacement

    Returns:
        Updated nn.module with original bert-style transformer layers
    """
    def replace_fn(child):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(bert_config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

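        # Split the fused QKV parameters back into per-projection tensors,
        # inverting the torch.cat in _copy_child_transformer_state.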
        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if transformer_config.pre_layer_norm:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if transformer_config.pre_layer_norm:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if transformer_config.pre_layer_norm:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b
        return orig_module

    return _replace_module(model=model,
                           orig_class=deepspeed.DeepSpeedTransformerLayer,
                           replace_fn=replace_fn)


def revert_module(orig_module_impl,
                  model,
                  orig_module_config,
                  replacement_module_config):
    """ Revert DeepSpeed's module back to the original client module.
    Arguments:
        orig_module_impl (torch.nn.Module): the original module that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        orig_module_config (dict): original module configuration
        replacement_module_config (dict): replacement deepspeed module configuration

    Returns:
        Updated nn.module with original bert-style transformer layers
    """
    assert isinstance(replacement_module_config, DeepSpeedTransformerConfig), \
        'Only DeepSpeedTransformerConfig is currently supported as replacement config'

    return _revert_transformer_layer(orig_layer_impl=orig_module_impl,
                                     model=model,
                                     bert_config=orig_module_config,
                                     transformer_config=replacement_module_config)


def _replace_module(model, orig_class, replace_fn):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
            desired type and return a new instance.

    Returns:
        A modified ``model``.
    """
    policy = {orig_class: replace_fn}
    return _replace_module_using_policies(model, policy)


def _replace_module_using_policies(model, policies):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.

    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            orig = repr(child)
            setattr(model, name, policies[child.__class__](child))
            new = getattr(model, name)
        else:
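            # No policy matches this child, so recurse into its own submodules.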
            _replace_module_using_policies(child, policies)

    return model
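
Taken together, `replace_module` and `revert_module` give a round trip between a client layer and the fused kernel. A minimal usage sketch, assuming a transformers 3.x-style `BertModel` (whose `BertLayer` exposes the attribute layout accessed above); the config values are illustrative:

```python
import transformers
from deepspeed.ops import DeepSpeedTransformerConfig, replace_module
from deepspeed.ops.module_inject import revert_module

bert_config = transformers.BertConfig()
model = transformers.BertModel(bert_config)

# pre_layer_norm and huggingface must match the client model's layer style.
ds_config = DeepSpeedTransformerConfig(batch_size=8,
                                       hidden_size=bert_config.hidden_size,
                                       intermediate_size=bert_config.intermediate_size,
                                       heads=bert_config.num_attention_heads,
                                       num_hidden_layers=bert_config.num_hidden_layers,
                                       pre_layer_norm=False,
                                       huggingface=True,
                                       training=False)

# Swap every BertLayer for a kernel-backed DeepSpeedTransformerLayer in place.
model = replace_module(transformers.modeling_bert.BertLayer, model, ds_config)

# ... run inference ...

# Copy the kernel state back into stock BertLayer modules.
model = revert_module(transformers.modeling_bert.BertLayer, model, bert_config, ds_config)
```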
28 changes: 25 additions & 3 deletions deepspeed/ops/transformer/transformer.py
@@ -91,6 +91,8 @@ class DeepSpeedTransformerConfig(TransformerConfig):
                 to turn it off in order to be able to reproduce the same result through the regular kernel execution.
 
             huggingface: Enable if using the HuggingFace interface style for returning the forward results.
+
+            training: Enable for training rather than inference.
     """
     def __init__(self,
                  batch_size=-1,
@@ -111,7 +113,8 @@ def __init__(self,
                  adjust_init_range=True,
                  attn_dropout_checkpoint=False,
                  stochastic_mode=False,
-                 huggingface=False):
+                 huggingface=False,
+                 training=True):
         super(DeepSpeedTransformerConfig,
               self).__init__(
                   batch_size,
@@ -131,7 +134,7 @@ def __init__(self,
         self.gelu_checkpoint = gelu_checkpoint  # True: if higher batch size is required
         self.adjust_init_range = adjust_init_range
         self.test_gemm = False
-        self.training = True
+        self.training = training
         self.is_grad_enabled = True
         self.attn_dropout_checkpoint = attn_dropout_checkpoint
         self.stochastic_mode = stochastic_mode
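
With `training` threaded through the constructor, an inference-only config can switch off the training path up front. A minimal sketch (all other arguments left at their defaults; values illustrative):

```python
from deepspeed.ops import DeepSpeedTransformerConfig

# training=False lets forward skip stashing activations for backward
# (see the `config.is_grad_enabled and config.training` guard below).
inference_config = DeepSpeedTransformerConfig(batch_size=1,
                                              stochastic_mode=False,
                                              huggingface=True,
                                              training=False)
```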
@@ -248,7 +251,7 @@ def forward(ctx,
             norm_w.register_hook(lambda x, self=self: grads.append([x, "norm_W"]))
             norm_b.register_hook(lambda x, self=self: grads.append([x, "norm_B"]))
 
-        if config.is_grad_enabled:
+        if config.is_grad_enabled and config.training:
             if (config.pre_layer_norm and config.normalize_invertible):
                 ctx.save_for_backward(input_mask,
                                       attn_qkvw,
@@ -405,6 +408,25 @@ def backward(ctx, grad_output):
              norm_w,
              norm_b)
 
+        # This appears to be an effective way to release context memory
+        ctx.qkv_tf = None
+        ctx.soft_inp = None
+        ctx.ctx_bufB = None
+        ctx.gelu_inp = None
+        ctx.ff2_inp = None
+        ctx.attn_o_inp = None
+        ctx.ff1_inp = None
+        ctx.add_res = None
+        ctx.inp_norm = None
+        ctx.config = None
+        ctx.attn_layer_norm_mean = None
+        ctx.layer_norm_mean = None
+        ctx.attn_prob_dropout_mask = None
+        ctx.attn_output_dropout_mask = None
+        ctx.layer_output_dropout_mask = None
+        ctx.attn_layer_norm_var = None
+        ctx.layer_norm_var = None
+
         return (grad_input,
                 None,
                 None,
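
Clearing `ctx` attributes at the end of `backward` follows a general autograd pattern: tensors stashed directly on `ctx` stay alive as long as the context object does, so dropping the references once they are consumed lets the allocator reclaim them sooner. A self-contained sketch of the pattern (hypothetical `Square` function, not part of this PR):

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.stashed = x  # stashed by hand, bypassing save_for_backward
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = 2 * ctx.stashed * grad_output
        ctx.stashed = None  # release the reference as soon as it is consumed
        return grad_input

y = Square.apply(torch.randn(4, requires_grad=True))
y.sum().backward()
```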
19 changes: 4 additions & 15 deletions deepspeed/runtime/utils.py
@@ -530,21 +530,10 @@ def see_memory_usage(message):
     # Print message except when distributed but not rank 0
     logger.info(message)
     logger.info(
-        "Memory Allocated %s GigaBytes ",
-        torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
-    )
-    logger.info(
-        "Max Memory Allocated %s GigaBytes",
-        torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
-    )
-    logger.info(
-        "Cache Allocated %s GigaBytes",
-        torch.cuda.memory_cached() / (1024 * 1024 * 1024),
-    )
-    logger.info(
-        "Max cache Allocated %s GigaBytes",
-        torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
-    )
+        f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
+        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
+        CA {round(torch.cuda.memory_cached() / (1024 * 1024 * 1024), 2)} GB \
+        Max_CA {round(torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))} GB ")


def call_to_str(base, *args, **kwargs):
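
The consolidated message keeps all four readings on one log line. A usage sketch (output values illustrative; the embedded line continuations mean the real line carries extra spacing):

```python
from deepspeed.runtime.utils import see_memory_usage

see_memory_usage("after forward")
# after forward
# MA 1.23 GB   Max_MA 2.05 GB   CA 1.5 GB   Max_CA 3 GB
```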