[zero] faster flatten/unflatten (cpp version) (#910)
* faster flatten/unflatten with apex

* switch to cpp flatten/unflatten

* style

* better comment

* missing import

* switch to build ops at run time

* fixes

Co-authored-by: Olatunji Ruwase <[email protected]>
stas00 and tjruwase authored Apr 14, 2021
1 parent 7003d44 commit 8b8ed2a
Showing 3 changed files with 161 additions and 177 deletions.
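
The change itself is mechanical: every call to torch._utils._flatten_dense_tensors / _unflatten_dense_tensors is replaced by the flatten / unflatten functions from DeepSpeed's C++ utils extension, which the optimizers now load (pre-built, or JIT-compiled on first use) through UtilsBuilder().load(). A minimal sketch of that pattern, assuming the C++ ops accept the same arguments as the torch._utils helpers, as the call sites in the diffs below suggest:

    # Sketch only; not part of the commit. Assumes DeepSpeed is installed and that the
    # C++ ops mirror the torch._utils flatten/unflatten call signatures.
    import torch
    from deepspeed.ops.op_builder import UtilsBuilder

    util_ops = UtilsBuilder().load()   # load a pre-built binary or JIT-compile the extension
    flatten = util_ops.flatten
    unflatten = util_ops.unflatten

    tensors = [torch.randn(4), torch.randn(3), torch.randn(5)]
    flat = flatten(tensors)            # one contiguous 1-D buffer holding all elements
    views = unflatten(flat, tensors)   # slices of `flat`, shaped like the original tensors
    for t, v in zip(tensors, views):
        assert v.shape == t.shape

Loading the ops at run time (the "switch to build ops at run time" bullet) means the extension does not have to be compiled at install time; each optimizer constructor below simply calls UtilsBuilder().load() and caches the two functions on self.
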
124 changes: 65 additions & 59 deletions deepspeed/runtime/zero/stage1.py
@@ -1,14 +1,14 @@
import math
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from collections import defaultdict

from deepspeed.runtime.zero.utils import _initialize_parameter_parallel_groups
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import get_grad_norm, CheckOverflow
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
from deepspeed.utils import logger, log_dist
from deepspeed.ops.op_builder import UtilsBuilder


def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
@@ -29,54 +29,6 @@ def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_c
return group_paddings


def flatten_dense_tensors_sub_partition_aligned(tensor_list,
dp,
max_elements_per_comm,
pg):
assert max_elements_per_comm >= dp, f"max_elements_per_comm {max_elements_per_comm} < dp {dp}"

num_elements = sum(t.numel() for t in tensor_list)
log_dist("Total number of elements in model: {}, max elements per com: {}".format(
num_elements,
max_elements_per_comm),
ranks=[0])

# Compute aligned partition size based on parameter count
aligned_param_partition_size = math.ceil(num_elements / dp)

# Compute aligned partition size based on communication size
aligned_comm_partition_size = int(max_elements_per_comm // dp)

if aligned_param_partition_size <= aligned_comm_partition_size:
sub_partition_count = 1
sub_partition_size = aligned_param_partition_size
else:
sub_partition_count = math.ceil(aligned_param_partition_size /
aligned_comm_partition_size)
sub_partition_size = aligned_comm_partition_size

# Compute required padding for alignment to dp and max_elements_per_comm
padding = (sub_partition_count * sub_partition_size * dp) - num_elements

log_dist(
f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}",
ranks=[0])
log_dist(
f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}",
ranks=[0])

if padding == 0:
aligned_tensor_list = tensor_list
else:
pad_tensor = torch.zeros(padding,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
aligned_tensor_list = tensor_list + [pad_tensor]

flat_tensors = _flatten_dense_tensors(aligned_tensor_list)
return flat_tensors


def _single_range_check(current_index, start_index, end_index, tensor_size):
offset = 0
if (current_index >= start_index) and (current_index < end_index):
@@ -127,6 +79,11 @@ def __init__(self,
max_elements_per_comm=5e8,
elastic_checkpoint=True):

# Load pre-built or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten

if dp_process_group is not None and partition_size is not None:
raise ValueError("Cannot specify both dp_process_group "
"and partition size")
@@ -209,7 +166,7 @@ def __init__(self,

# flattens all tensors into single 1d tensor aligned with sub-partition size for later dividing
# RS: create aligned sub-partitions
flat_aligned_params = flatten_dense_tensors_sub_partition_aligned(
flat_aligned_params = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=self.fp16_groups[i],
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=self.max_elems_per_comm[i],
@@ -218,8 +175,8 @@ def __init__(self,

# TODO: I don't think this does anything?
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
updated_params = self.unflatten(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

@@ -455,8 +412,8 @@ def get_all_sub_partition_info(tensor_list,

return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local

@staticmethod
def get_flat_sub_partitions(comm_tensor_list,
def get_flat_sub_partitions(self,
comm_tensor_list,
comm_param_offsets,
sub_partition_size,
dtype,
@@ -527,7 +484,7 @@ def get_flat_sub_partitions(comm_tensor_list,
partition_params.append(my_params) #flat_tensor_list)
final_param_offsets.append(my_offsets)
assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(len(flat_tensor_list), len(my_offsets))
flat_sub_partitions.append(_flatten_dense_tensors(flat_tensor_list))
flat_sub_partitions.append(self.flatten(flat_tensor_list))
if num_comm_intervals is not None and len(
flat_sub_partitions) < num_comm_intervals:
# logger.info("padding w. sub partitions to ensure uniform communication")
@@ -569,6 +526,55 @@ def free_grad_in_param_list(self, param_list):
else:
p.grad = None

def flatten_dense_tensors_sub_partition_aligned(self,
tensor_list,
dp,
max_elements_per_comm,
pg):
assert max_elements_per_comm >= dp, f"max_elements_per_comm {max_elements_per_comm} < dp {dp}"

num_elements = sum(t.numel() for t in tensor_list)
log_dist(
"Total number of elements in model: {}, max elements per com: {}".format(
num_elements,
max_elements_per_comm),
ranks=[0])

# Compute aligned partition size based on parameter count
aligned_param_partition_size = math.ceil(num_elements / dp)

# Compute aligned partition size based on communication size
aligned_comm_partition_size = int(max_elements_per_comm // dp)

if aligned_param_partition_size <= aligned_comm_partition_size:
sub_partition_count = 1
sub_partition_size = aligned_param_partition_size
else:
sub_partition_count = math.ceil(aligned_param_partition_size /
aligned_comm_partition_size)
sub_partition_size = aligned_comm_partition_size

# Compute required padding for alignment to dp and max_elements_per_comm
padding = (sub_partition_count * sub_partition_size * dp) - num_elements

log_dist(
f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}",
ranks=[0])
log_dist(
f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}",
ranks=[0])

if padding == 0:
aligned_tensor_list = tensor_list
else:
pad_tensor = torch.zeros(padding,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
aligned_tensor_list = tensor_list + [pad_tensor]

flat_tensors = self.flatten(aligned_tensor_list)
return flat_tensors

def reduce_scatter_gradients(self,
postscale_gradients,
gradient_predivide_factor,
@@ -699,8 +705,8 @@ def step(self, closure=None):

# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
updated_params = self.unflatten(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

@@ -903,7 +909,7 @@ def _retrieve_group_sub_partition_weights(self,
sub_partition_idx = (comm_idx * num_partitions) + rank
all_sub_partition_weights[sub_partition_idx] = sub_partition_weights

flat_merged_weights = flatten_dense_tensors_sub_partition_aligned(
flat_merged_weights = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_sub_partition_weights,
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=max_elems_per_comm,
@@ -951,7 +957,7 @@ def _partition_base_optimizer_state(self,
return all_partition_states[0]

alignment = dist.get_world_size(group=self.dp_process_group)
flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
flat_merged_partitions = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_partition_states,
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=max_elems_per_comm,
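
For reference, the alignment arithmetic in flatten_dense_tensors_sub_partition_aligned (moved above from module scope so it can call self.flatten) works as in this worked example with made-up numbers: num_elements = 10, dp = 4, max_elements_per_comm = 8.

    # Hypothetical numbers, only to illustrate the padding computation in the method above.
    import math

    num_elements, dp, max_elements_per_comm = 10, 4, 8

    aligned_param_partition_size = math.ceil(num_elements / dp)      # 3 elements per rank
    aligned_comm_partition_size = int(max_elements_per_comm // dp)   # 2 elements per comm slot

    if aligned_param_partition_size <= aligned_comm_partition_size:
        sub_partition_count, sub_partition_size = 1, aligned_param_partition_size
    else:
        sub_partition_count = math.ceil(aligned_param_partition_size /
                                        aligned_comm_partition_size)  # 2 sub-partitions
        sub_partition_size = aligned_comm_partition_size              # of 2 elements each

    padding = sub_partition_count * sub_partition_size * dp - num_elements
    # 2 * 2 * 4 - 10 = 6 zeros appended before self.flatten() is called
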
82 changes: 39 additions & 43 deletions deepspeed/runtime/zero/stage2.py
@@ -3,7 +3,6 @@
'''

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed.distributed_c10d import _get_global_rank
import torch.distributed as dist
import math
@@ -16,9 +15,8 @@
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.ops.adam import DeepSpeedCPUAdam

from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.utils import logger
from ...ops.op_builder import UtilsBuilder

#Toggle this to true to enable correctness test
#with gradient partitioning and without
@@ -52,28 +50,6 @@ def lcm(x, y):
return x * y // gcd(x, y)


# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(tensor_list, alignment):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()

remaining = num_elements % alignment

if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]

num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list

return _flatten_dense_tensors(padded_tensor_list)


def get_alignment_padding(tensor_list, alignment):
num_elements = sum([tensor.numel() for tensor in tensor_list])
remainder = num_elements % alignment
@@ -121,11 +97,6 @@ def __init__(self,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1):

# Load pre-installed or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten

if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
logger.info(f"Allgather bucket size {allgather_bucket_size}")
@@ -143,6 +114,11 @@ def __init__(self,
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer

# Load pre-built or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten

self.timers = timers

self.reduce_scatter = reduce_scatter
@@ -236,7 +212,7 @@ def __init__(self,

#create flat buffer in CPU and move to GPU
self.fp16_groups_flat.append(
flatten_dense_tensors_aligned(
self.flatten_dense_tensors_aligned(
self.fp16_groups[i],
dist.get_world_size(group=self.dp_process_group)).cuda(
torch.cuda.current_device()))
@@ -247,8 +223,8 @@ def __init__(self,
f"After Flattening and after emptying param group {i} cache")

# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
updated_params = self.unflatten(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

@@ -611,6 +587,27 @@ def report_ipg_memory_usage(self, tag, param_elems):
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
)

# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(self, tensor_list, alignment):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()

remaining = num_elements % alignment

if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]

num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list

return self.flatten(padded_tensor_list)

############### Independent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
@@ -1004,7 +1001,7 @@ def are_all_related_partitions_reduced(params_id):
self.param_dict[params_id].grad = None

def flatten_and_print(self, message, tensors, start=0, n=5):
flatten_tensor = _flatten_dense_tensors(tensors)
flatten_tensor = self.flatten(tensors)

def print_func():
logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
@@ -1327,7 +1324,7 @@ def get_flat_partition(self,
if return_tensor_list:
return flat_tensor_list

return _flatten_dense_tensors(flat_tensor_list)
return self.flatten(flat_tensor_list)

def free_grad_in_param_list(self, param_list):
for p in param_list:
@@ -1419,14 +1416,13 @@ def step(self, closure=None):
#create a flat gradients for parameters updated by this process
# If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
single_grad_partition = flatten_dense_tensors_aligned(
single_grad_partition = self.flatten_dense_tensors_aligned(
self.averaged_gradients[i],
int(self.partition_size[i])).to(
self.single_partition_of_fp32_groups[i].dtype)
else:
single_grad_partition = _flatten_dense_tensors(
self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype)
single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype)
assert single_grad_partition.numel() == self.partition_size[i], \
"averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id)

@@ -1507,8 +1503,8 @@ def step(self, closure=None):

# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
updated_params = self.unflatten(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

@@ -1749,7 +1745,7 @@ def _restore_from_fp32_weights(self, all_state_dict):
merged_partitions = [
sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
]
flat_merged_partitions = flatten_dense_tensors_aligned(
flat_merged_partitions = self.flatten_dense_tensors_aligned(
merged_partitions,
dist.get_world_size(group=self.dp_process_group))
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
@@ -1773,7 +1769,7 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
if torch.is_tensor(all_partition_states[0]):
flat_merged_partitions = flatten_dense_tensors_aligned(
flat_merged_partitions = self.flatten_dense_tensors_aligned(
all_partition_states,
alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
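
The speedup claimed in the commit title comes from replacing the Python-level flatten/unflatten helpers with a C++ op. A rough way to check it on a given machine is a micro-benchmark like the sketch below; this is illustrative only, not from the PR, and timings depend on hardware and tensor shapes.

    # Illustrative micro-benchmark sketch; not part of the commit. Compares the Python
    # helper in torch._utils with the DeepSpeed C++ util op on many small tensors.
    import time
    import torch
    from torch._utils import _flatten_dense_tensors
    from deepspeed.ops.op_builder import UtilsBuilder

    util_ops = UtilsBuilder().load()
    tensors = [torch.randn(1024) for _ in range(10000)]

    def bench(fn, iters=20):
        start = time.time()
        for _ in range(iters):
            fn(tensors)
        return (time.time() - start) / iters

    print(f"torch._utils._flatten_dense_tensors: {bench(_flatten_dense_tensors):.4f} s/iter")
    print(f"util_ops.flatten (C++):              {bench(util_ops.flatten):.4f} s/iter")
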