diff --git a/deepspeed/runtime/zero/stage1.py b/deepspeed/runtime/zero/stage1.py
index 7cd37f904faa..dde8424ceaad 100755
--- a/deepspeed/runtime/zero/stage1.py
+++ b/deepspeed/runtime/zero/stage1.py
@@ -1,7 +1,6 @@
 import math
 import torch
 import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from collections import defaultdict
 
 from deepspeed.runtime.zero.utils import _initialize_parameter_parallel_groups
@@ -9,6 +8,7 @@
 from deepspeed.runtime.utils import get_grad_norm, CheckOverflow
 from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
 from deepspeed.utils import logger, log_dist
+from deepspeed.ops.op_builder import UtilsBuilder
 
 
 def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
@@ -29,54 +29,6 @@ def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_c
     return group_paddings
 
 
-def flatten_dense_tensors_sub_partition_aligned(tensor_list,
-                                                dp,
-                                                max_elements_per_comm,
-                                                pg):
-    assert max_elements_per_comm >= dp, f"max_elements_per_comm {max_elements_per_comm} < dp {dp}"
-
-    num_elements = sum(t.numel() for t in tensor_list)
-    log_dist("Total number of elements in model: {}, max elements per com: {}".format(
-        num_elements,
-        max_elements_per_comm),
-             ranks=[0])
-
-    # Compute aligned partition size based on parameter count
-    aligned_param_partition_size = math.ceil(num_elements / dp)
-
-    # Compute aligned partition size based on communication size
-    aligned_comm_partition_size = int(max_elements_per_comm // dp)
-
-    if aligned_param_partition_size <= aligned_comm_partition_size:
-        sub_partition_count = 1
-        sub_partition_size = aligned_param_partition_size
-    else:
-        sub_partition_count = math.ceil(aligned_param_partition_size /
-                                        aligned_comm_partition_size)
-        sub_partition_size = aligned_comm_partition_size
-
-    # Compute required padding for alignment to dp and max_elements_per_comm
-    padding = (sub_partition_count * sub_partition_size * dp) - num_elements
-
-    log_dist(
-        f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}",
-        ranks=[0])
-    log_dist(
-        f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}",
-        ranks=[0])
-
-    if padding == 0:
-        aligned_tensor_list = tensor_list
-    else:
-        pad_tensor = torch.zeros(padding,
-                                 device=tensor_list[0].device,
-                                 dtype=tensor_list[0].dtype)
-        aligned_tensor_list = tensor_list + [pad_tensor]
-
-    flat_tensors = _flatten_dense_tensors(aligned_tensor_list)
-    return flat_tensors
-
-
 def _single_range_check(current_index, start_index, end_index, tensor_size):
     offset = 0
     if (current_index >= start_index) and (current_index < end_index):
@@ -127,6 +79,11 @@ def __init__(self,
                  max_elements_per_comm=5e8,
                  elastic_checkpoint=True):
 
+        # Load pre-built or JIT compile (un)flatten ops
+        util_ops = UtilsBuilder().load()
+        self.flatten = util_ops.flatten
+        self.unflatten = util_ops.unflatten
+
         if dp_process_group is not None and partition_size is not None:
             raise ValueError("Cannot specify both dp_process_group "
                              "and partition size")
@@ -209,7 +166,7 @@ def __init__(self,
 
             # flattens all tensors into single 1d tensor aligned with sub-partition size for later dividing
             # RS: create aligned sub-partitions
-            flat_aligned_params = flatten_dense_tensors_sub_partition_aligned(
+            flat_aligned_params = self.flatten_dense_tensors_sub_partition_aligned(
                 tensor_list=self.fp16_groups[i],
                 dp=dist.get_world_size(group=self.dp_process_group),
                 max_elements_per_comm=self.max_elems_per_comm[i],
@@ -218,8 +175,8 @@ def __init__(self,
 
             # TODO: I don't think this does anything?
             # set model fp16 weight to slices of flattened buffer
-            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
-                                                      self.fp16_groups[i])
+            updated_params = self.unflatten(self.fp16_groups_flat[i],
+                                            self.fp16_groups[i])
             for p, q in zip(self.fp16_groups[i], updated_params):
                 p.data = q.data
@@ -455,8 +412,8 @@ def get_all_sub_partition_info(tensor_list,
 
         return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local
 
-    @staticmethod
-    def get_flat_sub_partitions(comm_tensor_list,
+    def get_flat_sub_partitions(self,
+                                comm_tensor_list,
                                 comm_param_offsets,
                                 sub_partition_size,
                                 dtype,
@@ -527,7 +484,7 @@ def get_flat_sub_partitions(comm_tensor_list,
             partition_params.append(my_params)  #flat_tensor_list)
             final_param_offsets.append(my_offsets)
             assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(len(flat_tensor_list), len(my_offsets))
-            flat_sub_partitions.append(_flatten_dense_tensors(flat_tensor_list))
+            flat_sub_partitions.append(self.flatten(flat_tensor_list))
         if num_comm_intervals is not None and len(
                 flat_sub_partitions) < num_comm_intervals:
             # logger.info("padding w. sub partitions to ensure uniform communication")
@@ -569,6 +526,55 @@ def free_grad_in_param_list(self, param_list):
             else:
                 p.grad = None
 
+    def flatten_dense_tensors_sub_partition_aligned(self,
+                                                    tensor_list,
+                                                    dp,
+                                                    max_elements_per_comm,
+                                                    pg):
+        assert max_elements_per_comm >= dp, f"max_elements_per_comm {max_elements_per_comm} < dp {dp}"
+
+        num_elements = sum(t.numel() for t in tensor_list)
+        log_dist(
+            "Total number of elements in model: {}, max elements per com: {}".format(
+                num_elements,
+                max_elements_per_comm),
+            ranks=[0])
+
+        # Compute aligned partition size based on parameter count
+        aligned_param_partition_size = math.ceil(num_elements / dp)
+
+        # Compute aligned partition size based on communication size
+        aligned_comm_partition_size = int(max_elements_per_comm // dp)
+
+        if aligned_param_partition_size <= aligned_comm_partition_size:
+            sub_partition_count = 1
+            sub_partition_size = aligned_param_partition_size
+        else:
+            sub_partition_count = math.ceil(aligned_param_partition_size /
+                                            aligned_comm_partition_size)
+            sub_partition_size = aligned_comm_partition_size
+
+        # Compute required padding for alignment to dp and max_elements_per_comm
+        padding = (sub_partition_count * sub_partition_size * dp) - num_elements
+
+        log_dist(
+            f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}",
+            ranks=[0])
+        log_dist(
+            f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}",
+            ranks=[0])
+
+        if padding == 0:
+            aligned_tensor_list = tensor_list
+        else:
+            pad_tensor = torch.zeros(padding,
+                                     device=tensor_list[0].device,
+                                     dtype=tensor_list[0].dtype)
+            aligned_tensor_list = tensor_list + [pad_tensor]
+
+        flat_tensors = self.flatten(aligned_tensor_list)
+        return flat_tensors
+
     def reduce_scatter_gradients(self,
                                  postscale_gradients,
                                  gradient_predivide_factor,
@@ -699,8 +705,8 @@ def step(self, closure=None):
 
         # TODO: we probably don't need this? just to be safe
         for i in range(len(norm_groups)):
-            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
-                                                      self.fp16_groups[i])
+            updated_params = self.unflatten(self.fp16_groups_flat[i],
+                                            self.fp16_groups[i])
             for p, q in zip(self.fp16_groups[i], updated_params):
                 p.data = q.data
@@ -903,7 +909,7 @@ def _retrieve_group_sub_partition_weights(self,
                 sub_partition_idx = (comm_idx * num_partitions) + rank
                 all_sub_partition_weights[sub_partition_idx] = sub_partition_weights
 
-        flat_merged_weights = flatten_dense_tensors_sub_partition_aligned(
+        flat_merged_weights = self.flatten_dense_tensors_sub_partition_aligned(
             tensor_list=all_sub_partition_weights,
             dp=dist.get_world_size(group=self.dp_process_group),
            max_elements_per_comm=max_elems_per_comm,
@@ -951,7 +957,7 @@ def _partition_base_optimizer_state(self,
             return all_partition_states[0]
 
         alignment = dist.get_world_size(group=self.dp_process_group)
-        flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
+        flat_merged_partitions = self.flatten_dense_tensors_sub_partition_aligned(
            tensor_list=all_partition_states,
            dp=dist.get_world_size(group=self.dp_process_group),
            max_elements_per_comm=max_elems_per_comm,
diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index cd29625958c9..39d780e55574 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -3,7 +3,6 @@
 '''
 
 import torch
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed.distributed_c10d import _get_global_rank
 import torch.distributed as dist
 import math
@@ -16,9 +15,8 @@
 from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
 from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
 from deepspeed.ops.adam import DeepSpeedCPUAdam
-
+from deepspeed.ops.op_builder import UtilsBuilder
 from deepspeed.utils import logger
-from ...ops.op_builder import UtilsBuilder
 
 #Toggle this to true to enable correctness test
 #with gradient partitioning and without
@@ -52,28 +50,6 @@ def lcm(x, y):
     return x * y // gcd(x, y)
 
 
-# create a flat tensor aligned at the alignment boundary
-def flatten_dense_tensors_aligned(tensor_list, alignment):
-    num_elements = 0
-    for tensor in tensor_list:
-        num_elements = num_elements + tensor.numel()
-
-    remaining = num_elements % alignment
-
-    if remaining:
-        elements_to_add = alignment - remaining
-        pad_tensor = torch.zeros(elements_to_add,
-                                 device=tensor_list[0].device,
-                                 dtype=tensor_list[0].dtype)
-        padded_tensor_list = tensor_list + [pad_tensor]
-
-        num_elements = num_elements + elements_to_add
-    else:
-        padded_tensor_list = tensor_list
-
-    return _flatten_dense_tensors(padded_tensor_list)
-
-
 def get_alignment_padding(tensor_list, alignment):
     num_elements = sum([tensor.numel() for tensor in tensor_list])
     remainder = num_elements % alignment
@@ -121,11 +97,6 @@ def __init__(self,
                  gradient_predivide_factor=1.0,
                  gradient_accumulation_steps=1):
 
-        # Load pre-installed or JIT compile (un)flatten ops
-        util_ops = UtilsBuilder().load()
-        self.flatten = util_ops.flatten
-        self.unflatten = util_ops.unflatten
-
         if dist.get_rank() == 0:
             logger.info(f"Reduce bucket size {reduce_bucket_size}")
             logger.info(f"Allgather bucket size {allgather_bucket_size}")
@@ -143,6 +114,11 @@ def __init__(self,
             raise SystemError("Cannot use fp16 without CUDA.")
         self.optimizer = init_optimizer
 
+        # Load pre-built or JIT compile (un)flatten ops
+        util_ops = UtilsBuilder().load()
+        self.flatten = util_ops.flatten
+        self.unflatten = util_ops.unflatten
+
         self.timers = timers
 
         self.reduce_scatter = reduce_scatter
@@ -236,7 +212,7 @@ def __init__(self,
 
             #create flat buffer in CPU and move to GPU
             self.fp16_groups_flat.append(
-                flatten_dense_tensors_aligned(
+                self.flatten_dense_tensors_aligned(
                     self.fp16_groups[i],
                     dist.get_world_size(group=self.dp_process_group)).cuda(
                         torch.cuda.current_device()))
@@ -247,8 +223,8 @@ def __init__(self,
                 f"After Flattening and after emptying param group {i} cache")
 
             # set model fp16 weight to slices of flattened buffer
-            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
-                                                      self.fp16_groups[i])
+            updated_params = self.unflatten(self.fp16_groups_flat[i],
+                                            self.fp16_groups[i])
             for p, q in zip(self.fp16_groups[i], updated_params):
                 p.data = q.data
@@ -611,6 +587,27 @@ def report_ipg_memory_usage(self, tag, param_elems):
             f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
         )
 
+    # create a flat tensor aligned at the alignment boundary
+    def flatten_dense_tensors_aligned(self, tensor_list, alignment):
+        num_elements = 0
+        for tensor in tensor_list:
+            num_elements = num_elements + tensor.numel()
+
+        remaining = num_elements % alignment
+
+        if remaining:
+            elements_to_add = alignment - remaining
+            pad_tensor = torch.zeros(elements_to_add,
+                                     device=tensor_list[0].device,
+                                     dtype=tensor_list[0].dtype)
+            padded_tensor_list = tensor_list + [pad_tensor]
+
+            num_elements = num_elements + elements_to_add
+        else:
+            padded_tensor_list = tensor_list
+
+        return self.flatten(padded_tensor_list)
+
     ############### Independent Partition Gradient ########################
     def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
         if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
@@ -1004,7 +1001,7 @@ def are_all_related_partitions_reduced(params_id):
                 self.param_dict[params_id].grad = None
 
     def flatten_and_print(self, message, tensors, start=0, n=5):
-        flatten_tensor = _flatten_dense_tensors(tensors)
+        flatten_tensor = self.flatten(tensors)
 
         def print_func():
             logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
@@ -1327,7 +1324,7 @@ def get_flat_partition(self,
         if return_tensor_list:
             return flat_tensor_list
 
-        return _flatten_dense_tensors(flat_tensor_list)
+        return self.flatten(flat_tensor_list)
 
     def free_grad_in_param_list(self, param_list):
         for p in param_list:
@@ -1419,14 +1416,13 @@ def step(self, closure=None):
 
             #create a flat gradients for parameters updated by this process
             # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
            if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
-                single_grad_partition = flatten_dense_tensors_aligned(
+                single_grad_partition = self.flatten_dense_tensors_aligned(
                     self.averaged_gradients[i],
                     int(self.partition_size[i])).to(
                         self.single_partition_of_fp32_groups[i].dtype)
             else:
-                single_grad_partition = _flatten_dense_tensors(
-                    self.averaged_gradients[i]).to(
-                        self.single_partition_of_fp32_groups[i].dtype)
+                single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
+                    self.single_partition_of_fp32_groups[i].dtype)
             assert single_grad_partition.numel() == self.partition_size[i], \
                 "averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id)
@@ -1507,8 +1503,8 @@ def step(self, closure=None):
 
         # TODO: we probably don't need this? just to be safe
         for i in range(len(norm_groups)):
-            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
-                                                      self.fp16_groups[i])
+            updated_params = self.unflatten(self.fp16_groups_flat[i],
+                                            self.fp16_groups[i])
             for p, q in zip(self.fp16_groups[i], updated_params):
                 p.data = q.data
@@ -1749,7 +1745,7 @@ def _restore_from_fp32_weights(self, all_state_dict):
             merged_partitions = [
                 sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
             ]
-            flat_merged_partitions = flatten_dense_tensors_aligned(
+            flat_merged_partitions = self.flatten_dense_tensors_aligned(
                 merged_partitions,
                 dist.get_world_size(group=self.dp_process_group))
             dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
@@ -1773,7 +1769,7 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states):
         partition_id = dist.get_rank(group=self.dp_process_group)
         alignment = dist.get_world_size(group=self.dp_process_group)
         if torch.is_tensor(all_partition_states[0]):
-            flat_merged_partitions = flatten_dense_tensors_aligned(
+            flat_merged_partitions = self.flatten_dense_tensors_aligned(
                 all_partition_states,
                 alignment)
             dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 493106e93239..c7eb4b5cfc7b 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -6,7 +6,6 @@
 import os
 import torch
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed.distributed_c10d import _get_global_rank
 import torch.distributed as dist
 import math
@@ -18,26 +17,13 @@
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, ZeroParamType, _init_external_params, Init, is_zero_param
 from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS
 from deepspeed.ops.adam import DeepSpeedCPUAdam
+from deepspeed.ops.op_builder import UtilsBuilder
 import itertools
 
 # Toggle this to true to enable correctness test
 # with gradient partitioning and without
 pg_correctness_test = False
 
-try:
-    from apex_C import flatten
-    from apex_C import unflatten
-except ImportError:
-    try:
-        _ = warned_flatten
-    except NameError:
-        logger.warning(
-            "apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
-        )
-        warned_flatten = True
-    from torch._utils import _flatten_dense_tensors as flatten
-    from torch._utils import _unflatten_dense_tensors as unflatten
-
 
 def print_rank_0(message, debug=False, force=False):
     if torch.distributed.get_rank() == 0 and (debug or force):
@@ -71,28 +57,6 @@ def lcm(x, y):
     return x * y // gcd(x, y)
 
 
-# create a flat tensor aligned at the alignment boundary
-def flatten_dense_tensors_aligned(tensor_list, alignment):
-    num_elements = 0
-    for tens in tensor_list:
-        num_elements = num_elements + tens.numel()
-
-    remaining = num_elements % alignment
-
-    if remaining:
-        elements_to_add = alignment - remaining
-        pad_tensor = torch.zeros(elements_to_add,
-                                 device=tensor_list[0].device,
-                                 dtype=tensor_list[0].dtype)
-        padded_tensor_list = tensor_list + [pad_tensor]
-
-        num_elements = num_elements + elements_to_add
-    else:
-        padded_tensor_list = tensor_list
-
-    return _flatten_dense_tensors(padded_tensor_list)
-
-
 def move_to_cpu(tensor_list):
     for tensor in tensor_list:
         tensor.data = tensor.data.cpu()
@@ -598,6 +562,11 @@ def __init__(self,
             raise SystemError("Cannot use fp16 without CUDA.")
         self.optimizer = init_optimizer
 
+        # Load pre-built or JIT compile (un)flatten ops
+        util_ops = UtilsBuilder().load()
+        self.flatten = util_ops.flatten
+        self.unflatten = util_ops.unflatten
+
         if not all(is_zero_param(p) for p in module.parameters()):
             group = None
             if mpu:
@@ -872,7 +841,7 @@ def _create_fp16_partitions(self):
 
                 #create flat buffer in CPU and move to GPU
                 self.fp16_partitioned_groups_flat.append(
-                    flatten_dense_tensors_aligned(
+                    self.flatten_dense_tensors_aligned(
                         self.fp16_partitioned_groups[i],
                         dist.get_world_size(group=self.dp_process_group)).cuda(
                             torch.cuda.current_device()))
@@ -883,7 +852,7 @@ def _create_fp16_partitions(self):
                 #Without the detach, seems like the flattening becomes part of the
                 #model graph causing errors downstream
                 self.fp16_partitioned_groups_flat.append(
-                    flatten_dense_tensors_aligned(
+                    self.flatten_dense_tensors_aligned(
                         self.fp16_partitioned_groups[i],
                         dist.get_world_size(
                             group=self.dp_process_group)).detach().pin_memory())
@@ -893,9 +862,8 @@ def _create_fp16_partitions(self):
             see_memory_usage(f"After Flattening param group {i}", force=False)
 
             #set model fp16 weight to slices of flattened buffer
-            updated_params = _unflatten_dense_tensors(
-                self.fp16_partitioned_groups_flat[i],
-                self.fp16_partitioned_groups[i])
+            updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i],
+                                            self.fp16_partitioned_groups[i])
             for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params):
                 partitioned_param.data = q.data
@@ -961,9 +929,9 @@ def _create_fp16_partitions_with_defragmentation(self):
 
                 #create flat buffer in CPU and move to GPU
                 self.fp16_partitioned_groups_flat.append(
-                    flatten_dense_tensors_aligned(self.fp16_partitioned_groups[i],
-                                                  1).cuda(
-                                                      torch.cuda.current_device()))
+                    self.flatten_dense_tensors_aligned(
+                        self.fp16_partitioned_groups[i],
+                        1).cuda(torch.cuda.current_device()))
                 see_memory_usage(
                     f"After flattening and moving param group {i} to GPU",
                     force=False)
@@ -1741,7 +1709,7 @@ def are_all_related_partitions_reduced(params_id):
                 self.param_dict[params_id].grad = None
 
     def flatten_and_print(self, message, tensors, start=0, n=5):
-        flatten_tensor = _flatten_dense_tensors(tensors)
+        flatten_tensor = self.flatten(tensors)
 
         def print_func():
             logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
@@ -1799,7 +1767,7 @@ def set_none_gradients_to_zero(self, i, partition_id):
 
     def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
         rank = None
-        tensor = flatten(bucket)
+        tensor = self.flatten(bucket)
 
         tensor_to_allreduce = tensor
@@ -1829,7 +1797,7 @@ def allreduce_and_copy(self, small_bucket, rank=None, log=None):
         with torch.cuda.stream(self.reduction_stream):
             allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
             if rank is None or rank == dist.get_rank(group=self.dp_process_group):
-                for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
+                for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
                     buf.copy_(synced)
 
     def allreduce_no_retain(self,
@@ -2048,7 +2016,7 @@ def get_flat_partition(self,
         if return_tensor_list:
             return flat_tensor_list
 
-        return _flatten_dense_tensors(flat_tensor_list)
+        return self.flatten(flat_tensor_list)
 
     def free_grad_in_param_list(self, param_list):
         for p in param_list:
@@ -2158,9 +2126,8 @@ def old_step(self, closure=None):
 
             # create a flat gradients for parameters updated by this process
             # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
-            single_grad_partition = _flatten_dense_tensors(
-                self.averaged_gradients[i]).to(
-                    self.fp32_partitioned_groups_flat[i].dtype)
+            single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
+                self.fp32_partitioned_groups_flat[i].dtype)
 
             assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[i].numel(), \
                 "averaged gradients have different number of elements that partition size {} {} {} {}".format(
@@ -2174,11 +2141,10 @@ def old_step(self, closure=None):
             self.averaged_gradients[i] = None
             single_partition_grad_groups.append(single_grad_partition)
 
-            debug_fp32_grads[i] = [
-                (t.clone().detach(),
-                 t) for t in _unflatten_dense_tensors(single_grad_partition,
-                                                      group)
-            ]
+            debug_fp32_grads[i] = [(t.clone().detach(),
+                                    t)
+                                   for t in self.unflatten(single_grad_partition,
+                                                           group)]
 
         self.stop_timers([OPTIMIZER_FP32_GRADIENT])
@@ -2213,9 +2179,8 @@ def old_step(self, closure=None):
             #for p in self.fp16_groups[i]:
             #    p.data=p.ds_tensor
 
-            updated_params = _unflatten_dense_tensors(
-                self.fp16_partitioned_groups_flat[i],
-                self.fp16_partitioned_groups[i])
+            updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i],
+                                            self.fp16_partitioned_groups[i])
             for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params):
                 # print(f"Grad fn: {p.grad_fn}")
                 # p.data = torch.ones(1).half().cuda()
@@ -2269,9 +2234,8 @@ def _prepare_fp32_grad_for_sub_group(self, sub_group_id):
         partition_id = dist.get_rank(group=self.dp_process_group)
 
-        single_grad_partition = _flatten_dense_tensors(
-            self.averaged_gradients[sub_group_id]).to(
-                self.fp32_partitioned_groups_flat[sub_group_id].dtype)
+        single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to(
+            self.fp32_partitioned_groups_flat[sub_group_id].dtype)
 
         assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \
             "averaged gradients have different number of elements that partition size {} {} {} {}".format(
@@ -2302,10 +2266,30 @@ def _release_sub_group(self, sub_group_id, timer_names=set()):
         see_memory_usage(f'After release optimizer sub group {sub_group_id}',
                          force=False)
 
+    # create a flat tensor aligned at the alignment boundary
+    def flatten_dense_tensors_aligned(self, tensor_list, alignment):
+        num_elements = 0
+        for tens in tensor_list:
+            num_elements = num_elements + tens.numel()
+
+        remaining = num_elements % alignment
+
+        if remaining:
+            elements_to_add = alignment - remaining
+            pad_tensor = torch.zeros(elements_to_add,
+                                     device=tensor_list[0].device,
+                                     dtype=tensor_list[0].dtype)
+            padded_tensor_list = tensor_list + [pad_tensor]
+
+            num_elements = num_elements + elements_to_add
+        else:
+            padded_tensor_list = tensor_list
+
+        return self.flatten(padded_tensor_list)
+
     def _unflatten_partitioned_parameters(self, sub_group_id):
-        updated_params = _unflatten_dense_tensors(
-            self.fp16_partitioned_groups_flat[sub_group_id],
-            self.fp16_partitioned_groups[sub_group_id])
+        updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id],
+                                        self.fp16_partitioned_groups[sub_group_id])
 
         for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
             partitioned_param.data = q.data
@@ -2411,10 +2395,9 @@ def dump_post_step_gradients(self):
         for i, group in enumerate(self.fp16_groups):
             print(
                 f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT')
-            unflat_fp16 = _unflatten_dense_tensors(self.fp16_groups_flat[i],
-                                                   self.fp16_groups[i])
-            unflat_fp32 = _unflatten_dense_tensors(self.fp32_partitioned_groups_flat[i],
-                                                   self.fp16_groups[i])
+            unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i])
+            unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i],
+                                         self.fp16_groups[i])
             for j, p in enumerate(self.fp16_groups[i]):
                 param_id = self.get_param_id(p)
                 param_norm = float(p.data.float().norm(2))
@@ -2599,8 +2582,7 @@ def _set_loss_scale(self, value):
 
     def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings):
         # Remove paddings from flattened tensor
-        individual_tensors = _unflatten_dense_tensors(padded_flattened_tensor,
-                                                      group_tensors)
+        individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors)
         lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)]
         lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)]
         #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}')
@@ -2721,14 +2703,14 @@ def _get_flattened_partition(self, all_partition_states):
 
         local_state_partitions = []
         for param_index, param_slices in enumerate(param_partitions):
-            flattened_merged_tensor = flatten_dense_tensors_aligned(
+            flattened_merged_tensor = self.flatten_dense_tensors_aligned(
                 param_slices,
                 alignment)
            new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor)
            local_state_partitions.append(new_partitions[partition_id])
 
         if torch.is_tensor(local_state_partitions[0]):
-            return flatten_dense_tensors_aligned(local_state_partitions, alignment)
+            return self.flatten_dense_tensors_aligned(local_state_partitions, alignment)
 
         # Assume non-tensor states are not partitioned and equal across ranks, so return first one
         return local_state_partitions[0]
@@ -2783,7 +2765,7 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
 
         # update fp16 unflattened params
         for sub_group_id in range(len(self.fp16_partitioned_groups_flat)):
-            updated_params = _unflatten_dense_tensors(
+            updated_params = self.unflatten(
                 self.fp16_partitioned_groups_flat[sub_group_id],
                 self.fp16_partitioned_groups[sub_group_id])
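
Note (illustrative, not part of the patch): all three ZeRO stages now load the same flatten/unflatten kernels through UtilsBuilder and keep them on the optimizer as self.flatten / self.unflatten, instead of importing torch._utils' _flatten_dense_tensors / _unflatten_dense_tensors at module level. A minimal usage sketch of that pattern follows; it assumes the ops behave as drop-in replacements for the torch helpers (the premise of this change) and that a build toolchain is available if the op has to be JIT-compiled.

    import torch
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
    from deepspeed.ops.op_builder import UtilsBuilder

    # Load the pre-built extension, or JIT-compile it on first use.
    util_ops = UtilsBuilder().load()

    tensors = [torch.randn(3, 4), torch.randn(5)]
    flat = util_ops.flatten(tensors)           # one contiguous 1-D buffer
    views = util_ops.unflatten(flat, tensors)  # slices shaped like the originals

    # Sanity check against the torch._utils reference helpers being replaced.
    assert torch.equal(flat, _flatten_dense_tensors(tensors))
    for v, ref in zip(views, _unflatten_dense_tensors(flat, tensors)):
        assert torch.equal(v, ref)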