Add tensor parallelism related mappings #15360

Closed · wants to merge 20 commits
178 changes: 177 additions & 1 deletion src/transformers/utils/model_parallel_utils.py
@@ -12,10 +12,186 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import importlib
from collections import defaultdict
from math import ceil


class TPInfo(object):
Collaborator:

We need a docstring here explaining what this class does.

Also, the name is unclear to someone who is not familiar with all the model parallelism jargon. Let's expand to TensorParallelismInfo as we usually use descriptive names in Transformers.

def __init__(
self,
*name,
combined_qkv: bool = False,
reverse: bool = False,
):
Comment on lines +22 to +27
Collaborator:

Suggested change
def __init__(
self,
*name,
combined_qkv: bool = False,
reverse: bool = False,
):
def __init__(self, *name, combined_qkv: bool = False, reverse: bool = False):

This fits in one line (our char limit is 119).

self.name = name
self.combined_qkv = combined_qkv
self.reverse = reverse

def __str__(self):
return f"{self.__class__.__qualname__}({self.name})"

def __repr__(self):
return self.__str__()
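
Following the two review comments above, a minimal sketch of what the renamed class with a docstring could look like — the wording and argument descriptions are assumptions, not the code that was eventually merged:

class TensorParallelismInfo(object):
    """
    Container for tensor parallelism metadata about a group of parameter names.

    Subclasses (column-parallel, row-parallel, attribute update) attach a `code`
    describing how the listed parameters should be treated when a model is split
    across devices.

    Args:
        *name: parameter or attribute names this entry applies to.
        combined_qkv (`bool`): whether the weight fuses the query, key and value projections.
        reverse (`bool`): whether the weight is stored transposed (e.g. GPT-2's Conv1D layers).
    """

    def __init__(self, *name, combined_qkv: bool = False, reverse: bool = False):
        self.name = name
        self.combined_qkv = combined_qkv
        self.reverse = reverse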


Col = type("COLUMN", (TPInfo,), {"code": "Col"})
Collaborator:

Same remark as above about naming. Let's go for the full Column here; there is no point in saving three characters :-)

Row = type("ROW", (TPInfo,), {"code": "Row"})
Update = type("UPDATE", (TPInfo,), {"code": "Update"})


class TPMapping(object):
Collaborator:

And here TensorParallelismMapping. We also need a docstring to explain what this does and how to expand it.

__MAPPING__ = dict(
Albert=[
Col("query", "key", "value", "ffn"),
Row("attention.dense", "ffn_output"),
Update("num_attention_heads", "all_head_size"),
],
Bart=[
Col("q_proj", "k_proj", "v_proj", "fc1"),
Row("out_proj", "fc2"),
Update("embed_dim", "num_heads"),
],
Bert=[
Col("query", "key", "value", "intermediate.dense"),
Row("output.dense"),
Update("num_attention_heads", "all_head_size"),
],
T5=[
Col("q", "k", "v", "DenseReluDense.wi"),
Row("o", "DenseReluDense.wo", "relative_attention_bias"),
Update("d_model", "n_heads", "inner_dim"),
],
GPT2=[
Col("c_attn", reverse=True, combined_qkv=True),
Col("c_fc", "q_attn", reverse=True),
Row("c_proj", reverse=True),
Update("embed_dim", "split_size", "num_heads"),
],
GPTNeo=[
Col("q_proj", "k_proj", "v_proj", "c_fc"),
Row("out_proj", "c_proj"),
Update("embed_dim", "num_heads"),
],
GPTJ=[
Col("q_proj", "k_proj", "v_proj", "fc_in"),
Row("out_proj", "fc_out"),
Update("embed_dim", "num_attention_heads"),
],
Electra=[
Col("query", "key", "value", "intermediate.dense"),
Row("output.dense"),
Update("num_attention_heads", "all_head_size"),
],
Roberta=[
Col("query", "key", "value", "intermediate.dense"),
Row("output.dense"),
Update("num_attention_heads", "all_head_size"),
],
)
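
To illustrate the "how to expand it" part of the earlier comment: supporting another architecture means adding one more entry to this dict, where the key must match the prefix of an importable {Name}PreTrainedModel class and the values list the column-parallel, row-parallel and attribute-update names. The model and attribute names below are purely hypothetical:

    MyNewModel=[
        Col("q_proj", "k_proj", "v_proj", "up_proj"),
        Row("o_proj", "down_proj"),
        Update("hidden_size", "num_attention_heads"),
    ],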

def __init__(self):
cache_tp_mapping = {}

for cls_name, mapping in self.__MAPPING__.items():
cls = self._load_class_by_model_name(cls_name)
cache_tp_mapping[cls] = []

for elem in mapping:
for name in elem.name:
copy_elem = copy.deepcopy(elem)
copy_elem.name = name
cache_tp_mapping[cls].append(copy_elem)

self.__MAPPING__ = {cls: defaultdict(list) for cls in cache_tp_mapping}
# clear the existing mapping rather than creating a new mapping dict

for cls, mapping in cache_tp_mapping.items():
for elem in mapping:
self.__MAPPING__[cls][elem.code].append(elem)

@staticmethod
def _load_class_by_model_name(model_name):
transformers = importlib.import_module("transformers")
cls = getattr(transformers, f"{model_name}PreTrainedModel", None)
if cls is None:
cls = getattr(transformers, f"{model_name}PretrainedModel", None)
Comment on lines +117 to +119
Collaborator:

The second test should not be necessary. We don't have any PretrainedModel without the capital T (except for bart, but the class without the capital is deprecated and we have a class with the capital).

Contributor Author (@hyunwoongko, Feb 8, 2022):

@sgugger Can you let me know from which version this was applied? Users on lower transformers versions still need this in OSLO.

To add a little more: transformers will have this mapping, but OSLO will keep it internally as well, because this mapping class does not exist for users on lower versions of transformers. Users on lower versions will use the mapping inside OSLO, and users on higher versions will use the mapping from transformers. So this check is not required in transformers, but it is still required in OSLO.

assert cls is not None, f"Can not import the model named {cls}."
Collaborator:

We prefer an explicit test that raises an error in Transformers:

Suggested change
assert cls is not None, f"Can not import the model named {cls}."
if cls is not None:
raise ValueError(f"Can not import the model named {cls}.")

Contributor Author (@hyunwoongko, Feb 8, 2022):

Thanks for the correction! I think if cls is None is more correct. I'll apply your suggestion.

return cls
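
For reference, a sketch of the helper after applying the suggestion above — with the author's correction that the condition should be if cls is None, the fallback lookup dropped as discussed, and the error message switched to model_name (since cls is None at that point). This is an assumption about the final code, not the merged version:

    @staticmethod
    def _load_class_by_model_name(model_name):
        transformers = importlib.import_module("transformers")
        cls = getattr(transformers, f"{model_name}PreTrainedModel", None)
        if cls is None:
            raise ValueError(f"Can not import the model named {model_name}.")
        return cls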

def get_mapping(self, model):
for cls, mapping in self.__MAPPING__.items():
if isinstance(model, cls):
return dict(mapping)
return None

def column_parallel_params(self, model):
mapping = self.get_mapping(model)
if mapping is not None:
return mapping["Col"]

def row_parallel_params(self, model):
mapping = self.get_mapping(model)
if mapping is not None:
return mapping["Row"]

def update_attrs(self, model):
mapping = self.get_mapping(model)
if mapping is not None:
return mapping["Update"]

def search(self, model, param_name):
mapping = self.get_mapping(model)
if mapping is None:
raise ValueError(f"{model} does not support tensor parallelism.")
count_contain_elem_in_param = 0
param_split = param_name.split(".")
first_check = []

for code, elems in mapping.items():
for elem in elems:
if elem.name in param_name:
first_check.append(elem)

for elem in first_check:
elem_split = elem.name.split(".")
for split in elem_split:
if split in param_split:
count_contain_elem_in_param += 1
if count_contain_elem_in_param == len(elem_split):
return elem

return None

def is_combined_qkv_param(self, model, param_name):
elem = self.search(model, param_name)
if elem is not None:
return elem.combined_qkv

def get_combined_qkv_degree(self, model, param_name, module):
if self.is_combined_qkv_param(model, param_name) and hasattr(module, "weight"):
bigger = max(module.weight.size(0), module.weight.size(1))
smaller = min(module.weight.size(0), module.weight.size(1))
return bigger // smaller
return 1
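
As a concrete illustration of the degree computed above, assuming the default GPT-2 configuration: the fused c_attn projection is a Conv1D whose weight has shape (hidden_size, 3 * hidden_size), so the ratio comes out to 3.

    from transformers import GPT2Config, GPT2Model

    model = GPT2Model(GPT2Config())
    c_attn = model.h[0].attn.c_attn  # Conv1D; weight shape (768, 2304) with the default config
    degree = max(c_attn.weight.shape) // min(c_attn.weight.shape)
    print(degree)  # 3 -> query, key and value are packed into one weight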

def is_reversed_param(self, model, param_name):
elem = self.search(model, param_name)
if elem is not None:
return elem.reverse

def is_column_parallel(self, model, param_name):
elem = self.search(model, param_name)
if elem is not None:
return elem.code == "Col"

def is_row_parallel(self, model, param_name):
elem = self.search(model, param_name)
if elem is not None:
return elem.code == "Row"
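
Finally, a usage sketch of the mapping class as written in this diff, assuming a transformers version that exposes all of the architectures listed above; the model instantiation and parameter names are illustrative and follow the Bert entries:

    from transformers import BertConfig, BertModel

    tp_mapping = TPMapping()
    model = BertModel(BertConfig())

    # "query" matches a Col entry for Bert, "output.dense" matches a Row entry
    print(tp_mapping.is_column_parallel(model, "encoder.layer.0.attention.self.query.weight"))  # True
    print(tp_mapping.is_row_parallel(model, "encoder.layer.0.attention.output.dense.weight"))  # True
    print(tp_mapping.update_attrs(model))  # Update entries for Bert: num_attention_heads, all_head_size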


def assert_device_map(device_map, num_blocks):
blocks = list(range(0, num_blocks))
