
Commit

style
stas00 committed Feb 2, 2021
1 parent c414a3b commit d4b4476
Showing 2 changed files with 7 additions and 8 deletions.
7 changes: 3 additions & 4 deletions src/transformers/integrations.py
@@ -258,7 +258,6 @@ def rewrite_logs(d):
    return new_d


-
import torch


@@ -327,7 +326,7 @@ def ensure_divisibility(numerator, denominator):
        ranks = range(i, world_size, model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if i == (rank % model_parallel_size):
-           #print(f"DP ranks: {list(ranks)}")
+           # print(f"DP ranks: {list(ranks)}")
            _DATA_PARALLEL_GROUP = group
            _DATA_PARALLEL_GROUP_DEVICE_IDS = list(ranks)

@@ -339,7 +338,7 @@ def ensure_divisibility(numerator, denominator):
        ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if i == (rank // model_parallel_size):
-           #print(f"MP ranks: {list(ranks)}")
+           # print(f"MP ranks: {list(ranks)}")
            _MODEL_PARALLEL_GROUP = group
            _MODEL_PARALLEL_GROUP_DEVICE_IDS = list(ranks)

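As an aside, here is a minimal standalone sketch of the rank-grouping arithmetic in the two hunks above, assuming 4 GPUs and model_parallel_size=2 (the 2-for-MP/PP, 2-for-DP layout mentioned in the trainer.py change below); it only mirrors the range() expressions and is not code from this commit:

# Illustrative only: recompute the rank groups without torch.distributed,
# assuming world_size=4 and model_parallel_size=2.
world_size = 4
model_parallel_size = 2
data_parallel_size = world_size // model_parallel_size

# Data-parallel groups stride across the model-parallel dimension.
dp_groups = [list(range(i, world_size, model_parallel_size)) for i in range(model_parallel_size)]
# Model-parallel groups are contiguous blocks of model_parallel_size ranks.
mp_groups = [list(range(i * model_parallel_size, (i + 1) * model_parallel_size)) for i in range(data_parallel_size)]

print(dp_groups)  # [[0, 2], [1, 3]]
print(mp_groups)  # [[0, 1], [2, 3]]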
@@ -550,7 +549,7 @@ def init_deepspeed(trainer, num_training_steps, mpu):
        model=model,
        model_parameters=model_parameters,
        config_params=config,
-       mpu = mpu,
+       mpu=mpu,
    )

    return model, optimizer, lr_scheduler
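The mpu object passed to deepspeed.initialize above is the MPU class defined on this branch in integrations.py. Purely for orientation, a Megatron-style mpu is typically a thin wrapper over the two process groups; the sketch below uses the conventional method names as an assumption and is not the actual MPU implementation:

import torch

# Hypothetical sketch of a Megatron-style mpu wrapper; the real MPU class on
# this branch lives in src/transformers/integrations.py and is not shown here.
class ToyMPU:
    def __init__(self, model_parallel_group, data_parallel_group):
        self._mp_group = model_parallel_group
        self._dp_group = data_parallel_group

    def get_model_parallel_group(self):
        return self._mp_group

    def get_data_parallel_group(self):
        return self._dp_group

    def get_model_parallel_rank(self):
        return torch.distributed.get_rank(group=self._mp_group)

    def get_model_parallel_world_size(self):
        return torch.distributed.get_world_size(group=self._mp_group)

    def get_data_parallel_rank(self):
        return torch.distributed.get_rank(group=self._dp_group)

    def get_data_parallel_world_size(self):
        return torch.distributed.get_world_size(group=self._dp_group)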
8 changes: 4 additions & 4 deletions src/transformers/trainer.py
@@ -258,7 +258,6 @@ def __init__(
        else:
            self.is_model_parallel = False

-
        self.mpu = None
        # XXX: for now hack over naive MP to have the same behavior
        if len(self.args.pipeline):
@@ -271,7 +270,7 @@ def __init__(
            pp_args_str = self.args.pipeline.split()
            if len(pp_args_str):
                for x in pp_args_str:
-                   k,v = x.split("=")
+                   k, v = x.split("=")
                    pp_args[k] = v

            if "chunks" in pp_args:
@@ -301,10 +300,11 @@ def __init__(
            # 2D Parallel
            if self.args.deepspeed:
                from .integrations import MPU
+
                self.mpu = MPU()
-               #n_gpus = torch.distributed.get_world_size()
+               # n_gpus = torch.distributed.get_world_size()
                # XXX: hardcoded for 2 gpus for PP/MP - needs to be configurable
-               #n_gpus_per_mp = n_gpus/2
+               # n_gpus_per_mp = n_gpus/2
                # at the moment experimenting with just 4 gpus - hence 2 gpus for MP|PP, 2 for DP
                self.mpu.initialize_model_parallel(pp_args["n_gpus_per_mp"])

