add unit test for auto parallel amp acc align. (#67535)
winter-wang authored Aug 20, 2024
1 parent 3650481 commit 3901873
Showing 1 changed file with 114 additions and 10 deletions: test/auto_parallel/hybrid_strategy/semi_auto_llama_acc_align.py
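For context, the test is driven largely by environment variables (parallel degrees, accumulation steps, AMP switches). The sketch below is a hypothetical driver: the variable names are the ones read in __init__ in the diff, while the launch command, device list, and concrete values are assumptions.

import os
import subprocess

# Hypothetical driver for semi_auto_llama_acc_align.py. The env-var names are the
# ones the test reads; the device list and the chosen values are illustrative only.
env = dict(
    os.environ,
    dp="2",                  # data-parallel degree
    mp="2",                  # tensor/model-parallel degree
    pp="2",                  # pipeline-parallel degree
    acc_step="1",            # gradient accumulation steps
    amp="true",              # enable AMP in dist.Strategy()._amp
    amp_dtype="bfloat16",
    amp_level="O2",
    amp_master_grad="true",
)
subprocess.check_call(
    [
        "python", "-m", "paddle.distributed.launch",
        "--devices", "0,1,2,3,4,5,6,7",
        "semi_auto_llama_acc_align.py",
    ],
    env=env,
)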
@@ -14,6 +14,8 @@

import hashlib
import os

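# Run the whole test with the PIR program representation enabled; the flag is set
# before paddle is imported so it is picked up when the framework initializes.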
os.environ["FLAGS_enable_pir_api"] = str(1)
import random
from functools import reduce

@@ -91,16 +93,110 @@ def __init__(self):
self.pp = int(os.getenv("pp"))
if os.getenv("use_sp") == "true":
self.config.sequence_parallel = True
-        self.gradient_accumulation_steps = int(os.getenv("acc_step"))

self.strategy = dist.Strategy()

# amp config
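        # (every knob below comes from an environment variable, so the test launcher
        # can sweep dtype/level/master_grad combinations without editing this file)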
amp = self.strategy._amp
if os.getenv("amp"):
            amp.enable = os.getenv("amp")
if os.getenv("amp_dtype"):
amp.dtype = os.getenv("amp_dtype")
if os.getenv("amp_level"):
amp.level = os.getenv("amp_level")
if os.getenv("amp_master_grad"):
amp.use_master_grad = os.getenv("amp_master_grad")
if os.getenv("scale_loss"):
            amp.init_loss_scaling = float(os.getenv("scale_loss"))
if os.getenv("amp_custom_black_list"):
            amp.custom_black_list = os.getenv(
                "amp_custom_black_list"
            ).split(",")
if os.getenv("amp_custom_white_list"):
            amp.custom_white_list = os.getenv(
                "amp_custom_white_list"
            ).split(",")

self.gradient_accumulation_steps = 1
if os.getenv("acc_step"):
self.gradient_accumulation_steps = int(os.getenv("acc_step"))

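        # With accumulation enabled, static-graph gradient merge sums k_steps
        # micro-batch gradients (avg=False) before each optimizer step, mirroring
        # gradient accumulation in the dynamic-graph run.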
if self.gradient_accumulation_steps > 1:
self.strategy.gradient_merge.enable = True
self.strategy.gradient_merge.k_steps = (
self.gradient_accumulation_steps
)
self.strategy.gradient_merge.avg = False

self.config.recompute = False
self.config.sep_parallel_degree = 1

-        self.run_step_dynamic = 10
self.run_step = 10
self.run_step_dy2static = (
-            self.run_step_dynamic // self.gradient_accumulation_steps
self.run_step // self.gradient_accumulation_steps
)

self.init_dist_env()

    def run_llama(self, to_static=0):
# model
model = LlamaForCausalLMAuto(self.config)
criterion = LlamaPretrainingCriterionAuto(self.config)
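        # For O2 AMP, decorate the model up front (low-precision parameters plus
        # optional master gradients) so the dynamic run is set up the same way as
        # the program later built by dist.to_static from the same strategy.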
if self.strategy._amp.enable and self.strategy._amp.level == "O2":
paddle.amp.decorate(
models=model,
level=self.strategy._amp.level,
dtype=self.strategy._amp.dtype,
master_grad=self.strategy._amp.use_master_grad,
)

# optimizer
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
learning_rate=0.0001, warmup_steps=2, start_lr=0, end_lr=0.0001
)
optimizer = create_optimizer(model, lr_scheduler)
optimizer = dist.shard_optimizer(optimizer)

# dataloader
train_dataset = RandomDataset(self.config.seq_length)
train_sampler = BatchSampler(
train_dataset,
batch_size=2,
shuffle=True,
drop_last=True,
)
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=0,
)
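        # Shard batches along the data-parallel axis; both pipeline-stage meshes are
        # passed so input_ids land on the first stage and labels on the last.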
dist_loader = dist.shard_dataloader(
dataloader=train_dataloader,
meshes=[get_mesh(0), get_mesh(1)],
shard_dims="dp",
)

if to_static:
model = dist.to_static(
model, dist_loader, criterion, optimizer, strategy=self.strategy
)
model.train()
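        # Record a per-step md5 digest of the loss; run_test_cases later compares the
        # dynamic and dy2static digests element-wise rather than within a tolerance.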
md5_list = []
for step, inputs in enumerate(dist_loader()):
if step >= self.run_step:
break
input_ids, labels = inputs
if to_static:
loss = model(input_ids, labels)
if loss is None:
numpy_array = np.array([])
else:
numpy_array = np.array(loss)
array_bytes = numpy_array.tobytes()
md5_list.append(hashlib.md5(array_bytes).hexdigest())
else:
logits = model(input_ids)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
optimizer.clear_grad()
md5_list.append(loss._local_value()._md5sum())
lr_scheduler.step()
return md5_list

def init_dist_env(self):
order = ["dp", "pp", "mp"]
@@ -154,7 +250,7 @@ def run_dynamic(self):
model.train()
#####
for step, inputs in enumerate(dist_loader()):
-            if step >= self.run_step_dynamic:
if step >= self.run_step:
break

input_ids, labels = inputs
@@ -248,11 +344,19 @@ def run_dy2static(self):

def run_test_cases(self):
self.init_dist_env()
-        dy_loss_md5 = self.run_dynamic()
-        self.init_dist_env()
-        st_loss_md5 = self.run_dy2static()
-        if int(dist.get_rank()) in [2, 3, 6, 7]:
-            assert dy_loss_md5 == st_loss_md5
if self.gradient_accumulation_steps > 1:
dy_loss_md5 = self.run_dynamic()
self.init_dist_env()
st_loss_md5 = self.run_dy2static()
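            # Ranks 2, 3, 6 and 7 form the last pipeline stage (assuming the 2-way
            # dp/pp/mp mesh this test runs with under the ["dp", "pp", "mp"] order),
            # so only they hold a loss value to compare.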
if int(dist.get_rank()) in [2, 3, 6, 7]:
assert dy_loss_md5 == st_loss_md5
else:
dy_loss_md5 = self.run_llama(to_static=0)
self.init_dist_env()
st_loss_md5 = self.run_llama(to_static=1)
assert len(dy_loss_md5) == len(st_loss_md5)
for idx in range(len(dy_loss_md5)):
assert dy_loss_md5[idx] == st_loss_md5[idx]


if __name__ == '__main__':