ZeRO-Offload passing model tests #374

Merged
36 changes: 25 additions & 11 deletions deepspeed/runtime/zero/stage2.py
@@ -15,6 +15,7 @@
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
+ from deepspeed.ops.adam import DeepSpeedCPUAdam

from deepspeed.utils import logger
#Toggle this to true to enable correctness test
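For reference, DeepSpeedCPUAdam is DeepSpeed's host-side Adam implementation, which ZeRO-Offload uses to step the optimizer on CPU. A minimal construction sketch, assuming deepspeed is installed with the CPU Adam op built (the model here is a placeholder):

    import torch
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    model = torch.nn.Linear(8, 8)
    # Mirrors torch.optim.Adam's hyperparameters, but the update runs on CPU.
    optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3)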
@@ -153,10 +154,13 @@ def __init__(self,

self.reduce_scatter = reduce_scatter

- self.overlap_comm = overlap_comm or cpu_offload
+ self.overlap_comm = overlap_comm

self.cpu_offload = cpu_offload

+ self.deepspeed_adam_offload = (cpu_offload
+     and type(init_optimizer) == DeepSpeedCPUAdam)

self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'

self.dp_process_group = dp_process_group
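The new deepspeed_adam_offload flag separates the optimized offload path from the generic one: the fused update is taken only when the wrapped optimizer is exactly DeepSpeedCPUAdam. A standalone sketch of the check (the helper name is illustrative; the logic follows the added lines above):

    from deepspeed.ops.adam import DeepSpeedCPUAdam

    def uses_ds_adam_offload(cpu_offload, init_optimizer):
        # Exact type() comparison as in the diff, so subclasses of
        # DeepSpeedCPUAdam deliberately fall back to the generic path.
        return cpu_offload and type(init_optimizer) == DeepSpeedCPUAdam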
@@ -1358,12 +1362,15 @@ def step(self, closure=None):
"reducing to {}".format(dist.get_rank(),
prev_scale,
self.loss_scale))
timers('optimizer_gradients').start()
timers('optimizer_gradients').stop()
timers('optimizer_step').start()
timers('optimizer_step').stop()
timers('optimizer_allgather').start()
timers('optimizer_allgather').stop()
return

timers('optimizer_gradients').start()
norm_groups = []
single_partition_grad_groups = []
skip = False
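The start()/stop() pairs in the overflow early-return keep the timer table consistent: step() logs 'optimizer_gradients', 'optimizer_step', and 'optimizer_allgather' at the end, so each timer must be started at least once even when the step is skipped after a loss-scale overflow. A simplified stand-in for the pattern (the Timer class here only approximates DeepSpeed's named timers):

    import time

    class Timer:
        # Minimal stand-in: accumulates wall-clock time between start/stop.
        def __init__(self):
            self.elapsed, self._t0 = 0.0, None
        def start(self):
            self._t0 = time.time()
        def stop(self):
            self.elapsed += time.time() - self._t0

    timers = {n: Timer() for n in
              ('optimizer_gradients', 'optimizer_step', 'optimizer_allgather')}

    def step(overflow):
        if overflow:
            # Zero-duration brackets so a later log over all three names
            # never reads a timer that was never started.
            for t in timers.values():
                t.start()
                t.stop()
            return
        # ...normal path brackets each phase with start()/stop()...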
@@ -1405,24 +1412,27 @@ def step(self, closure=None):
single_partition_grad_groups.append(single_grad_partition)

self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
timers('optimizer_gradients').stop()

#torch.set_num_threads(12)
timers('optimizer_step').start()
- if self.cpu_offload:
-     # self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
-     self.optimizer.step()
-     for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
-         fp16_partitions[partition_id].data.copy_(fp32_partition.data)
+ if self.deepspeed_adam_offload:
+     self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
+     #self.optimizer.step()
+     #for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+     #    fp16_partitions[partition_id].data.copy_(fp32_partition.data)
else:
self.optimizer.step()

- #get rid of the fp32 gradients. Not needed anymore
- for group in self.single_partition_of_fp32_groups:
-     group.grad = None
+ if not self.cpu_offload:
+     for group in self.single_partition_of_fp32_groups:
+         group.grad = None
+
+ for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+     fp16_partitions[partition_id].data.copy_(fp32_partition.data)

timers('optimizer_step').stop()
timers.log(names=['optimizer_step'])

if self.cpu_offload:
self.reset_cpu_buffers()
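The reworked block above routes the update through one of two paths: with DeepSpeedCPUAdam, passing fp16_param_groups lets the optimizer write the updated fp32 master weights straight back into the fp16 partitions during its own update loop; with any other optimizer, a plain step() is followed by an explicit copy pass. A condensed standalone sketch (the function name is illustrative, attribute names follow the diff, and the exact nesting in the full method may differ):

    def zero2_substep(opt, partition_id):
        if opt.deepspeed_adam_offload:
            # Fused path: update + fp16 write-back in a single pass.
            opt.optimizer.step(
                fp16_param_groups=opt.parallel_partitioned_fp16_groups)
        else:
            opt.optimizer.step()
            if not opt.cpu_offload:
                # fp32 gradients live on GPU here; release them eagerly.
                for group in opt.single_partition_of_fp32_groups:
                    group.grad = None
            # Manual pass: copy fp32 master weights into the fp16 shards.
            for fp16_parts, fp32_part in zip(
                    opt.parallel_partitioned_fp16_groups,
                    opt.single_partition_of_fp32_groups):
                fp16_parts[partition_id].data.copy_(fp32_part.data)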
@@ -1469,6 +1479,10 @@ def step(self, closure=None):
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

+ timers.log(names=['optimizer_gradients',
+                   'optimizer_step',
+                   'optimizer_allgather'])
see_memory_usage('After zero_optimizer step')
return

@@ -1574,14 +1588,14 @@ def backward(self, loss, retain_graph=False):

if self.contiguous_gradients:
self.ipg_buffer = []
- buf_0 = torch.empty(int(self.reduce_bucket_size * 4.5),
+ buf_0 = torch.empty(int(self.reduce_bucket_size * 2.5),
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0)

# Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm:
- buf_1 = torch.empty(int(self.reduce_bucket_size * 4.5),
+ buf_1 = torch.empty(int(self.reduce_bucket_size * 2.5),
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_1)
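The allocation above shrinks each IPG bucket from 4.5x to 2.5x of reduce_bucket_size, and the comment in the diff explains the second buffer: with overlap_comm, reduction of one buffer can run on the communication stream while new gradients are packed into the other. A standalone sketch of the allocation (the helper name is illustrative; dtype and sizing follow the diff):

    import torch

    def allocate_ipg_buffers(reduce_bucket_size, overlap_comm):
        size = int(reduce_bucket_size * 2.5)
        buffers = [torch.empty(size,
                               dtype=torch.half,
                               device=torch.cuda.current_device())]
        if overlap_comm:
            # Double buffering: avoids a read/write conflict between the
            # stream reducing buffer 0 and the one filling buffer 1.
            buffers.append(torch.empty(size,
                                       dtype=torch.half,
                                       device=torch.cuda.current_device()))
        return buffers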
87 changes: 50 additions & 37 deletions tests/model/Megatron_GPT2/run_func_test.py
@@ -216,11 +216,12 @@ def test_mp4_gpu4_node1_zero2(self):
basic_run_config = test_config
succ = self.run_test(basic_run_config, 0.01)
self.assertTrue(succ)

partition_activation_config = test_config
succ = self.run_partition_activations_test(partition_activation_config, 0.01)
self.assertTrue(succ)

- def test_mp1_gpu1_node1_zero2_offload(self):
+ def test_mp1_gpu1_node1_zero2_ds_offload(self):
test_config = {
"mp": 1,
"gpus": 1,
@@ -235,10 +236,10 @@ def test_mp1_gpu1_node1_zero2_offload(self):
"json": "ds_config_func_bs4_zero2_offload.json",
"cpu_optimizer": True,
}
- succ = self.run_test(test_config, 0.01)
+ succ = self.run_test(test_config, 0.02)
self.assertTrue(succ)

- def test_mp1_gpu2_node1_zero2_offload(self):
+ def test_mp1_gpu2_node1_zero2_ds_offload(self):
test_config = {
"mp": 1,
"gpus": 2,
@@ -253,7 +254,7 @@ def test_mp1_gpu2_node1_zero2_offload(self):
"json": "ds_config_func_bs8_zero2_offload.json",
"cpu_optimizer": True,
}
- succ = self.run_test(test_config, 0.01)
+ succ = self.run_test(test_config, 0.02)
self.assertTrue(succ)

def test_mp2_gpu4_node1_zero2_gas(self):
@@ -278,7 +279,7 @@ def test_mp2_gpu4_node1_zero2_gas(self):
succ = self.run_partition_activations_test(test_config, 0.01)
self.assertTrue(succ)

- def test_mp2_gpu4_node1_zero2_offload(self):
+ def test_mp2_gpu4_node1_zero2_ds_offload(self):
test_config = {
"mp": 2,
"gpus": 4,
@@ -295,14 +296,14 @@ def test_mp2_gpu4_node1_zero2_offload(self):
}

basic_run_config = test_config
- succ = self.run_test(basic_run_config, 0.01)
+ succ = self.run_test(basic_run_config, 0.02)
self.assertTrue(succ)

partition_activation_config = test_config
- succ = self.run_partition_activations_test(partition_activation_config, 0.01)
+ succ = self.run_partition_activations_test(partition_activation_config, 0.02)
self.assertTrue(succ)

- def test_mp4_gpu4_node1_zero2_offload(self):
+ def test_mp4_gpu4_node1_zero2_ds_offload(self):
test_config = {
"mp": 4,
"gpus": 4,
@@ -319,14 +320,14 @@ def test_mp4_gpu4_node1_zero2_offload(self):
}

basic_run_config = test_config
- succ = self.run_test(basic_run_config, 0.01)
+ succ = self.run_test(basic_run_config, 0.02)
self.assertTrue(succ)

partition_activation_config = test_config
- succ = self.run_partition_activations_test(partition_activation_config, 0.01)
+ succ = self.run_partition_activations_test(partition_activation_config, 0.02)
self.assertTrue(succ)

- def test_mp1_gpu1_node1_zero2_cpu_optimizer(self):
+ def test_mp1_gpu1_node1_zero2_torch_offload(self):
test_config = {
"mp": 1,
"gpus": 1,
@@ -338,14 +339,15 @@ def test_mp1_gpu1_node1_zero2_cpu_optimizer(self):
"seq_length": SEQ_LEN,
"heads": ATTN_HEADS,
"deepspeed": False,
"json": "ds_config_func_bs4_zero2.json",
"json": "ds_config_func_bs4_zero2_offload.json",
"cpu_optimizer": True,
"test_torch_offload": True,
}

succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)

- def test_mp1_gpu2_node1_zero2_cpu_optimizer(self):
+ def test_mp1_gpu2_node1_zero2_torch_offload(self):
test_config = {
"mp": 1,
"gpus": 2,
@@ -357,14 +359,15 @@ def test_mp1_gpu2_node1_zero2_cpu_optimizer(self):
"seq_length": SEQ_LEN,
"heads": ATTN_HEADS,
"deepspeed": False,
"json": "ds_config_func_bs8_zero2.json",
"json": "ds_config_func_bs8_zero2_offload.json",
"cpu_optimizer": True,
"test_torch_offload": True,
}

succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)

- def test_mp2_gpu4_node1_zero2_cpu_optimizer(self):
+ def test_mp2_gpu4_node1_zero2_torch_offload(self):
test_config = {
"mp": 2,
"gpus": 4,
@@ -376,8 +379,9 @@ def test_mp2_gpu4_node1_zero2_cpu_optimizer(self):
"seq_length": SEQ_LEN,
"heads": ATTN_HEADS,
"deepspeed": False,
"json": "ds_config_func_bs8_zero2.json",
"json": "ds_config_func_bs8_zero2_offload.json",
"cpu_optimizer": True,
"test_torch_offload": True,
}

basic_run_config = test_config
Expand All @@ -388,7 +392,7 @@ def test_mp2_gpu4_node1_zero2_cpu_optimizer(self):
succ = self.run_partition_activations_test(partition_activation_config, 0.01)
self.assertTrue(succ)

- def test_mp4_gpu4_node1_zero2_cpu_optimizer(self):
+ def test_mp4_gpu4_node1_zero2_torch_offload(self):
test_config = {
"mp": 4,
"gpus": 4,
@@ -400,8 +404,9 @@ def test_mp4_gpu4_node1_zero2_cpu_optimizer(self):
"seq_length": SEQ_LEN,
"heads": ATTN_HEADS,
"deepspeed": False,
"json": "ds_config_func_bs8_zero2.json",
"json": "ds_config_func_bs8_zero2_offload.json",
"cpu_optimizer": True,
"test_torch_offload": True,
}

basic_run_config = test_config
@@ -439,11 +444,7 @@ def run_partition_activations_test(self, test_config, r_tol):

deepspeed_config = test_config["json"]
baseline_deepspeed_config = False

if "cpu_optimizer" in test_config and test_config["cpu_optimizer"]:
cpu_optimizer_flag = "--cpu-optimizer"
else:
cpu_optimizer_flag = ""
cpu_optimizer_flag = self.gen_cpu_optimizer_flag(test_config, True)

# baseline run...
# turnoff deepspeed if baseline deepspeed config
@@ -469,6 +470,7 @@

# DeepSpeed run...
test_config["deepspeed"] = True
+ cpu_optimizer_flag = self.gen_cpu_optimizer_flag(test_config, False)
test_config[
"other_args"] = f"\"--deepspeed-activation-checkpointing {cpu_optimizer_flag}\""
test_config["json"] = deepspeed_config
@@ -488,11 +490,7 @@ def run_test(self, test_config, r_tol):

deepspeed_config = test_config["json"]
baseline_deepspeed_config = False

if "cpu_optimizer" in test_config and test_config["cpu_optimizer"]:
cpu_optimizer_flag = "--cpu-optimizer"
else:
cpu_optimizer_flag = ""
cpu_optimizer_flag = self.gen_cpu_optimizer_flag(test_config, True)

# baseline run...
# turn off deepspeed if a baseline deepspeed config
@@ -520,6 +518,7 @@

# DeepSpeed run...
test_config["deepspeed"] = True
+ cpu_optimizer_flag = self.gen_cpu_optimizer_flag(test_config, False)
test_config["other_args"] = f"\"{cpu_optimizer_flag}\""

print("{0}: DeepSpeed run.".format(self.id()))
@@ -551,25 +550,39 @@ def check_parity(self, base_file, test_file, r_tol):

return True

+    def gen_cpu_optimizer_flag(self, test_config, is_baseline):
+        if 'cpu_optimizer' in test_config and test_config['cpu_optimizer']:
+            cpu_optimizer_flag = "--cpu-optimizer"
+            if is_baseline:
+                cpu_optimizer_flag += " --cpu_torch_adam"
+                return cpu_optimizer_flag
+            if 'test_torch_offload' in test_config and test_config['test_torch_offload']:
+                cpu_optimizer_flag += " --cpu_torch_adam"
+            return cpu_optimizer_flag
+        else:
+            cpu_optimizer_flag = ""
+
+        return cpu_optimizer_flag
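The helper's behavior, extracted here into a standalone function for illustration (same logic, condensed), shows the three flag combinations the tests exercise: baselines always run torch's CPU Adam, and DeepSpeed runs use DeepSpeedCPUAdam unless test_torch_offload forces torch Adam:

    def gen_cpu_optimizer_flag(test_config, is_baseline):
        # Condensed copy of the method above, minus self.
        if test_config.get('cpu_optimizer'):
            flag = "--cpu-optimizer"
            if is_baseline or test_config.get('test_torch_offload'):
                flag += " --cpu_torch_adam"
            return flag
        return ""

    assert gen_cpu_optimizer_flag({"cpu_optimizer": True}, True) == \
        "--cpu-optimizer --cpu_torch_adam"
    assert gen_cpu_optimizer_flag({"cpu_optimizer": True}, False) == \
        "--cpu-optimizer"
    assert gen_cpu_optimizer_flag(
        {"cpu_optimizer": True, "test_torch_offload": True}, False) == \
        "--cpu-optimizer --cpu_torch_adam"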


def suite():
suite = unittest.TestSuite()

suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_fp16'))

# Baseline = Megatron + Torch.Optim.Adam
- # Test = Megatron + Torch.Optim.Adam + ZeRO-Stage-2
- suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero2_cpu_optimizer'))
- suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero2_cpu_optimizer'))
- suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero2_cpu_optimizer'))
- suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero2_cpu_optimizer'))
+ # Test = Megatron + Torch.Optim.Adam + ZeRO-Offload
+ suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero2_torch_offload'))
+ suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero2_torch_offload'))
+ suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero2_torch_offload'))
+ suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero2_torch_offload'))

# Baseline = Megatron + Torch.Optim.Adam
# Test = Megatron + DeepSpeedAdam + ZeRO-Offload
- suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero2_offload'))
- suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero2_offload'))
- suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero2_offload'))
- suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero2_offload'))
+ suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero2_ds_offload'))
+ suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero2_ds_offload'))
+ suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero2_ds_offload'))
+ suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero2_ds_offload'))

suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero1'))
suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero1'))
2 changes: 1 addition & 1 deletion tests/model/run_sanity_check.py
@@ -30,7 +30,7 @@ def pytest_hack(runner_result):


def test_megatron():
- runner = unittest.TextTestRunner(failfast=False)
+ runner = unittest.TextTestRunner(failfast=True)
pytest_hack(runner.run(Megatron_GPT2.suite()))
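failfast=True makes the runner abort the whole Megatron suite at the first failing case instead of running every remaining test. A minimal standalone illustration of the unittest behavior (class and test names are placeholders):

    import unittest

    class Demo(unittest.TestCase):
        def test_a(self):
            self.fail("first failure stops the run")
        def test_b(self):
            pass  # never reached: failfast aborts after test_a fails

    runner = unittest.TextTestRunner(failfast=True)
    runner.run(unittest.TestLoader().loadTestsFromTestCase(Demo))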

