diff --git a/README.md b/README.md
index 8a312ba219b56..d48c7a55389ee 100755
--- a/README.md
+++ b/README.md
@@ -15,14 +15,11 @@
## Latest News
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat).
-* ***[2023/04] π [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)*** [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[δΈζ](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[ζ₯ζ¬θͺ](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]π
+* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)
+* [2023/04] π [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[δΈζ](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[ζ₯ζ¬θͺ](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]π
* [2023/03] [Scaling Large-Scale Generative Mixture-of-Expert Multimodal Model With VL-MoE](https://www.deepspeed.ai/2023/03/30/multi-modal.html)
* [2023/02] [Automatic Tensor Parallelism: Enables tensor parallelism by default without an injection policy](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/)
* [2022/12] [DeepSpeed Data Efficiency: A composable library that makes better use of data, increases training efficiency, and improves model quality](https://www.deepspeed.ai/2022/12/11/data-efficiency.html)
-* [2022/11] [Stable Diffusion Image Generation under 1 second w. DeepSpeed MII](https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/benchmark/txt2img)
-* [2022/10] [DeepSpeed-MII: instant speedup on 24,000+ open-source DL models with up to 40x cheaper inference](https://www.deepspeed.ai/2022/10/10/mii.html)
-* [2022/09] [ZeRO-Inference: Democratizing massive model inference](https://www.deepspeed.ai/2022/09/09/zero-inference.html)
-* [2022/07] [Azure and DeepSpeed empower easy-to-use and high-performance model training](https://azure.microsoft.com/en-us/blog/azure-empowers-easytouse-highperformance-and-hyperscale-model-training-using-deepspeed/)
---
diff --git a/csrc/includes/quantization.h b/csrc/includes/quantization.h
index 826797889ebbe..de87734137f8a 100644
--- a/csrc/includes/quantization.h
+++ b/csrc/includes/quantization.h
@@ -40,6 +40,32 @@ void launch_dequantize_kernel(T* dequant_data,
int total_elems,
cudaStream_t stream);
+void launch_swizzled_quant(int8_t* q_data,
+ float* q_scales,
+ const __half* input_data,
+ int num_bits,
+ quantize::Type q_type,
+ int groups,
+ int elems_per_group,
+ int pipelining,
+ int nodes,
+ int devices_per_node,
+ cudaStream_t stream);
+
+void launch_dequant_reduce(int8_t* reduced_data,
+ float* reduced_scales,
+ const int8_t* input_data,
+ const float* input_scales,
+ int num_gpus,
+ int num_bits,
+ quantize::Type quant_type,
+ int out_groups,
+ int elems_per_out_group,
+ int elems_per_in_tensor,
+ int groups_per_in_tensor,
+ int elems_per_in_group,
+ cudaStream_t stream);
+
template
void launch_fake_quantize_kernel(T* vals,
int total_count,
diff --git a/csrc/quantization/pt_binding.cpp b/csrc/quantization/pt_binding.cpp
index ccc0c15be1a66..2bc9f89bbee97 100644
--- a/csrc/quantization/pt_binding.cpp
+++ b/csrc/quantization/pt_binding.cpp
@@ -136,6 +136,95 @@ at::Tensor dequantize(at::Tensor& quantized_data,
return output;
}
+std::vector ds_swizzle_quant(at::Tensor& input_vals,
+ int groups,
+ int num_bits,
+ quantize::Type quant_type,
+ int pipeline_size,
+ int nodes,
+ int devices_per_node)
+{
+ auto scales_options = at::TensorOptions()
+ .dtype(at::kFloat)
+ .layout(at::kStrided)
+ .device(at::kCUDA)
+ .requires_grad(false);
+ const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
+ auto scales = torch::empty({groups, scales_elems}, scales_options);
+
+ auto output_options = at::TensorOptions()
+ .dtype(at::kChar)
+ .layout(at::kStrided)
+ .device(at::kCUDA)
+ .requires_grad(false);
+
+ const int quantization_scalar = 8 / num_bits;
+ const int compressed_vals = at::numel(input_vals) / quantization_scalar;
+
+ auto output = torch::empty({compressed_vals}, output_options);
+ const int elems_per_group = at::numel(input_vals) / groups;
+
+ launch_swizzled_quant((int8_t*)output.data_ptr(),
+ (float*)scales.data_ptr(),
+ (__half*)input_vals.data_ptr(),
+ num_bits,
+ quant_type,
+ groups,
+ elems_per_group,
+ pipeline_size,
+ nodes,
+ devices_per_node,
+ at::cuda::getCurrentCUDAStream());
+
+ return {output, scales};
+}
+
+std::vector quantized_reduction(at::Tensor& input_vals,
+ at::Tensor& input_scales,
+ int in_groups,
+ int out_groups,
+ int num_bits,
+ quantize::Type quant_type)
+{
+ auto scales_options = at::TensorOptions()
+ .dtype(at::kFloat)
+ .layout(at::kStrided)
+ .device(at::kCUDA)
+ .requires_grad(false);
+ const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
+ auto scales = torch::empty({out_groups, scales_elems}, scales_options);
+
+ auto output_options = at::TensorOptions()
+ .dtype(at::kChar)
+ .layout(at::kStrided)
+ .device(at::kCUDA)
+ .requires_grad(false);
+
+ std::vector sz(input_vals.sizes().begin(), input_vals.sizes().end());
+ const int gpu_per_node = 16; // depend on machine in_groups/out_groups;
+ sz[sz.size() - 1] = sz.back() / gpu_per_node; // num of GPU per nodes
+ const int elems_per_in_tensor = at::numel(input_vals) / gpu_per_node;
+ auto output = torch::empty(sz, output_options);
+
+ const int elems_per_in_group = elems_per_in_tensor / (in_groups / gpu_per_node);
+ const int elems_per_out_group = elems_per_in_tensor / out_groups;
+
+ launch_dequant_reduce((int8_t*)output.data_ptr(),
+ (float*)scales.data_ptr(),
+ (const int8_t*)input_vals.data_ptr(),
+ (const float*)input_scales.data_ptr(),
+ gpu_per_node,
+ num_bits,
+ quant_type,
+ out_groups,
+ elems_per_out_group,
+ elems_per_in_tensor,
+ in_groups / gpu_per_node,
+ elems_per_in_group,
+ at::cuda::getCurrentCUDAStream());
+ return {output, scales};
+}
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize, "DeepSpeed Quantize with fp32 (CUDA)");
@@ -158,4 +247,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("quantize", &quantize_kernel);
m.def("dequantize", &dequantize<__half>);
m.def("dequantize_fp32", &dequantize);
+ m.def("swizzle_quant", &ds_swizzle_quant);
+ m.def("quantized_reduction", &quantized_reduction);
}
diff --git a/csrc/quantization/quant_reduce.cu b/csrc/quantization/quant_reduce.cu
new file mode 100644
index 0000000000000..26db1118c831a
--- /dev/null
+++ b/csrc/quantization/quant_reduce.cu
@@ -0,0 +1,263 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include
+#include "dequantization_utils.h"
+#include "ds_kernel_utils.h"
+#include "memory_access_utils.h"
+#include "quantization_utils.h"
+#include "reduction_utils.h"
+
+using rop = reduce::ROpType;
+
+/*
+TODO(cmikeh2): Add implementation that better handles larger nodes. It would like make sense
+to leverage some parallel reductions here to improve performance.
+*/
+
+template
+__global__ void __launch_bounds__(1024) dequant_reduce(int8_t* reduced_data,
+ float* reduced_scales,
+ const int8_t* input_data,
+ const float* input_scales,
+ int elems_per_out_group,
+ int elems_per_in_tensor,
+ int groups_per_in_tensor,
+ int elems_per_in_group,
+ int num_tensors)
+{
+ cg::thread_block tb = cg::this_thread_block();
+ cg::thread_block_tile warp = cg::tiled_partition(tb);
+
+ // NOTE(cmikeh2): This probably could be hardcoded to a larger number,
+ // but that means even stronger restrictions on the number of elements per group
+ // A performance analysis here might be beneficial
+ constexpr int mem_granularity = (numBits == 8) ? 8 : 4;
+ constexpr int elems_per_load = mem_granularity / sizeof(int8_t); // div by 1
+ constexpr int storage_values = 16 / sizeof(__half2);
+
+ const int block_offset = tb.group_index().x * elems_per_out_group;
+ const int elem_offset = tb.thread_index().x * elems_per_load;
+ const int base_offset = block_offset + elem_offset;
+ const int stride = tb.group_dim().x * elems_per_load;
+
+ __half2 local_buffer[totalChunks * storage_values];
+
+ quantize::GroupStats stats;
+
+#pragma unroll
+ for (int i = 0; i < totalChunks; i++) {
+ __half2* iteration_buffer = local_buffer + i * storage_values;
+
+#pragma unroll
+ for (int j = 0; j < storage_values; j++) {
+ iteration_buffer[j] = reduce::init();
+ }
+
+ const int iter_offset = i * stride + base_offset;
+ const int iter_scale_idx = iter_offset / elems_per_in_group;
+ bool do_loads = i * stride + elem_offset < elems_per_out_group;
+
+ if (numTensors > 0) {
+#pragma unroll
+ for (int j = 0; j < numTensors; j++) {
+ if (do_loads) {
+ int8_t load_buffer[elems_per_load];
+
+ mem_access::load_global(
+ load_buffer, input_data + j * elems_per_in_tensor + iter_offset);
+
+ quantize::Params params(
+ input_scales + j * groups_per_in_tensor, iter_scale_idx);
+
+ __half2 dequant_buffer[storage_values];
+ dequantize::chunk(dequant_buffer, load_buffer, params);
+
+#pragma unroll
+ for (int k = 0; k < storage_values; k++) {
+ iteration_buffer[k] =
+ reduce::element(iteration_buffer[k], dequant_buffer[k]);
+ }
+ }
+ }
+ } else {
+#pragma unroll 4
+ for (int j = 0; j < num_tensors; j++) {
+ if (do_loads) {
+ int8_t load_buffer[elems_per_load];
+
+ mem_access::load_global(
+ load_buffer, input_data + j * elems_per_in_tensor + iter_offset);
+
+ quantize::Params params(
+ input_scales + j * groups_per_in_tensor, iter_scale_idx);
+
+ __half2 dequant_buffer[storage_values];
+ dequantize::chunk(dequant_buffer, load_buffer, params);
+
+#pragma unroll
+ for (int k = 0; k < storage_values; k++) {
+ iteration_buffer[k] =
+ reduce::element(iteration_buffer[k], dequant_buffer[k]);
+ }
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j = 0; j < storage_values; j++) { stats.update(iteration_buffer[j]); }
+ }
+
+ auto params = stats.template get_params(tb, warp);
+
+ if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); }
+
+#pragma unroll
+ for (int i = 0; i < totalChunks; i++) {
+ const int iter_offset = i * stride + base_offset;
+ if (i * stride + elem_offset < elems_per_out_group) {
+ int8_t local_output[elems_per_load];
+ quantize::_chunk(
+ local_output, local_buffer + i * storage_values, params);
+ mem_access::store_global(reduced_data + iter_offset, local_output);
+ }
+ }
+}
+
+template
+int32_t pow2_round(int32_t raw_value)
+{
+ return (((raw_value - 1) >> Power) + 1) << Power;
+}
+
+#define LAUNCH_DEQUANT_REDUCE(num_chunks) \
+ dequant_reduce \
+ <<>>(reduced_data, \
+ reduced_scales, \
+ input_data, \
+ input_scales, \
+ elems_per_out_group, \
+ elems_per_in_tensor, \
+ groups_per_in_tensor, \
+ elems_per_in_group, \
+ num_tensors);
+
+template
+void launch_dequant_reduce_impl(int8_t* reduced_data,
+ float* reduced_scales,
+ const int8_t* input_data,
+ const float* input_scales,
+ int out_groups,
+ int elems_per_out_group,
+ int elems_per_in_tensor,
+ int groups_per_in_tensor,
+ int elems_per_in_group,
+ int num_tensors,
+ cudaStream_t stream)
+{
+ // This is a coincidence. This is derived by 8 halves per 16 bytes with 2-way packing for int4
+ constexpr int elems_per_thread = numBits;
+ const int one_step_threads =
+ next_pow2((elems_per_out_group + elems_per_thread - 1) / (elems_per_thread));
+ // TODO(cmikeh2): Tune this
+ const int threads = (one_step_threads < 1024) ? one_step_threads : 1024;
+
+ dim3 block(threads);
+ dim3 grid(out_groups);
+
+ const int elems_per_step = threads * elems_per_thread;
+ const int unroll_raw = (elems_per_out_group + elems_per_step - 1) / elems_per_step;
+
+ const int unroll = (unroll_raw >= 4) ? pow2_round<1>(unroll_raw) : unroll_raw;
+
+ if (unroll == 1) {
+ // 0-4096 elems
+ LAUNCH_DEQUANT_REDUCE(1);
+ } else if (unroll == 2) {
+ // 4097-8192 etc...
+ LAUNCH_DEQUANT_REDUCE(2);
+ } else if (unroll == 3) {
+ LAUNCH_DEQUANT_REDUCE(3);
+ } else if (unroll == 4) {
+ LAUNCH_DEQUANT_REDUCE(4);
+ } else if (unroll == 6) {
+ LAUNCH_DEQUANT_REDUCE(6);
+ } else if (unroll == 8) {
+ LAUNCH_DEQUANT_REDUCE(8);
+ } else if (unroll == 10) {
+ LAUNCH_DEQUANT_REDUCE(10);
+ } else if (unroll == 12) {
+ // 48k limit
+ LAUNCH_DEQUANT_REDUCE(12);
+ } else {
+ assert(false);
+ }
+}
+
+#define LAUNCH_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE) \
+ launch_dequant_reduce_impl(reduced_data, \
+ reduced_scales, \
+ input_data, \
+ input_scales, \
+ out_groups, \
+ elems_per_out_group, \
+ elems_per_in_tensor, \
+ groups_per_in_tensor, \
+ elems_per_in_group, \
+ num_gpus, \
+ stream);
+
+void launch_dequant_reduce(int8_t* reduced_data,
+ float* reduced_scales,
+ const int8_t* input_data,
+ const float* input_scales,
+ int num_gpus,
+ int num_bits,
+ quantize::Type quant_type,
+ int out_groups,
+ int elems_per_out_group,
+ int elems_per_in_tensor,
+ int groups_per_in_tensor,
+ int elems_per_in_group,
+ cudaStream_t stream)
+{
+ if (quant_type == quantize::Type::Symmetric) {
+ if (num_bits == 4) {
+ if (num_gpus == 8) {
+ LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Symmetric);
+ } else if (num_gpus == 16) {
+ LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Symmetric);
+ } else {
+ LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Symmetric);
+ }
+ } else if (num_bits == 8) {
+ if (num_gpus == 8) {
+ LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Symmetric);
+ } else if (num_gpus == 16) {
+ LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Symmetric);
+ } else {
+ LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Symmetric);
+ }
+ }
+ } else if (quant_type == quantize::Type::Asymmetric) {
+ if (num_bits == 4) {
+ if (num_gpus == 8) {
+ LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Asymmetric);
+ } else if (num_gpus == 16) {
+ LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Asymmetric);
+ } else {
+ LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Asymmetric);
+ }
+ } else if (num_bits == 8) {
+ if (num_gpus == 8) {
+ LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Asymmetric);
+ } else if (num_gpus == 16) {
+ LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Asymmetric);
+ } else {
+ LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Asymmetric);
+ }
+ }
+ }
+}
diff --git a/csrc/quantization/swizzled_quantize.cu b/csrc/quantization/swizzled_quantize.cu
new file mode 100644
index 0000000000000..5a02a0ae8120a
--- /dev/null
+++ b/csrc/quantization/swizzled_quantize.cu
@@ -0,0 +1,196 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "memory_access_utils.h"
+#include "quantization_utils.h"
+#include "reduction_utils.h"
+
+using rop = reduce::ROpType;
+
+namespace swiz_quant {
+constexpr int max_threads = 512;
+constexpr int min_threads = 32;
+
+constexpr int step_granularity = 2;
+constexpr int h_per_step = step_granularity * quantize::h_per_load;
+} // namespace swiz_quant
+
+template
+__global__ void swizzled_quant_kernel(int8_t* quantized_data,
+ float* quantized_scales,
+ const __half* uncompressed_data,
+ int elems_per_group,
+ int nodes,
+ int devices_per_node)
+{
+ cg::thread_block tb = cg::this_thread_block();
+ cg::thread_block_tile warp = cg::tiled_partition(tb);
+
+ // Indexing offsets, same as normal quantization for in-case
+ const int block_rank = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+ const int block_offset = block_rank * elems_per_group;
+ const int elem_offset = tb.thread_index().x * quantize::h_per_load;
+ const int base_offset = block_offset + elem_offset;
+ const int stride = tb.size() * quantize::h_per_load;
+ const __half* input_base = uncompressed_data + base_offset;
+
+ // Local buffer
+ __half2 local_buffer[totalChunks * quantize::h2_per_load];
+
+ quantize::GroupStats stats;
+#pragma unroll
+ for (int i = 0; i < totalChunks; i++) {
+ __half2* iteration_buffer = local_buffer + i * quantize::h2_per_load;
+
+ mem_access::load_global(
+ iteration_buffer, input_base + i * stride, elem_offset + i * stride < elems_per_group);
+
+#pragma unroll
+ for (int j = 0; j < quantize::h2_per_load; j++) { stats.update(iteration_buffer[j]); }
+ }
+
+ auto params = stats.template get_params(tb, warp);
+
+ const int partition_id = blockIdx.z;
+ const int partition_offset = partition_id / devices_per_node;
+ const int partition_base = (partition_id % devices_per_node) * nodes;
+ const int pipelining_offset = blockIdx.y * (devices_per_node * nodes);
+ const int output_partition = (pipelining_offset + partition_base + partition_offset);
+
+ constexpr int out_scalar_effect = 8 / numBits;
+ const int out_block_rank = output_partition * gridDim.x + blockIdx.x;
+ const int out_block_offset = out_block_rank * elems_per_group / out_scalar_effect;
+ const int out_base_offset = out_block_offset + elem_offset / out_scalar_effect;
+ int8_t* out_base = quantized_data + out_base_offset;
+
+ const int out_stride = stride / out_scalar_effect;
+ constexpr int num_int8_out = quantize::h_per_load / out_scalar_effect;
+
+ if (tb.thread_index().x == 0) { params.store(quantized_scales, out_block_rank); }
+
+#pragma unroll
+ for (int i = 0; i < totalChunks; i++) {
+ if (i * stride + elem_offset < elems_per_group) {
+ int8_t local_output[quantize::h_per_load / out_scalar_effect];
+ quantize::_chunk(
+ local_output, local_buffer + i * quantize::h2_per_load, params);
+ mem_access::store_global(out_base + i * out_stride, local_output);
+ }
+ }
+}
+
+#define LAUNCH_SWIZZLE_QUANT(total_chunks, threads) \
+ swizzled_quant_kernel<<>>( \
+ q_data, q_scales, input_data, elems_per_group, nodes, devices_per_node);
+
+/*
+Swizzled quantization reorganizes the quantized groups in order to better facilitate
+communication. As an example of the partitioning scheme we have the following example
+of 2 node, 4 device swizzling:
+
+ --- --- --- --- --- --- --- ---
+| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
+ --- --- --- --- --- --- --- ---
+becomes
+ --- --- --- --- --- --- --- ---
+| 0 | 4 | 1 | 5 | 2 | 6 | 3 | 7 |
+ --- --- --- --- --- --- --- ---
+
+Multiple quantization groups may be mapped into a single partition. In order to better support
+later pipelining, we may also perform an additional slicing. In two-way slicing, for instance,
+the first halves of each partition are concatenated.
+*/
+
+template
+void launch_swizzled_quant_impl(int8_t* q_data,
+ float* q_scales,
+ const __half* input_data,
+ int groups,
+ int elems_per_group,
+ int pipelining,
+ int nodes,
+ int devices_per_node,
+ cudaStream_t stream)
+{
+ const int one_step_threads =
+ next_pow2((elems_per_group + swiz_quant::h_per_step - 1) / swiz_quant::h_per_step);
+ const int max_threads = (one_step_threads < swiz_quant::max_threads) ? one_step_threads
+ : swiz_quant::max_threads;
+ const int threads = (max_threads < swiz_quant::min_threads) ? swiz_quant::min_threads
+ : max_threads;
+
+ dim3 block(threads);
+ const int groups_per_partition = groups / (nodes * devices_per_node);
+ assert(groups_per_partition % pipelining == 0);
+ const int contiguous_groups = groups_per_partition / pipelining;
+ const int partitions = nodes * devices_per_node;
+ dim3 grid(contiguous_groups, pipelining, partitions);
+
+ const int elems_per_step = threads * swiz_quant::h_per_step;
+ const int external_unroll = ((elems_per_group + elems_per_step - 1) / elems_per_step);
+ const int total_unroll = external_unroll * swiz_quant::step_granularity;
+
+ assert(total_unroll % 2 == 0);
+
+ if (threads == 32) {
+ LAUNCH_SWIZZLE_QUANT(2, 32);
+ } else if (threads == 64) {
+ LAUNCH_SWIZZLE_QUANT(2, 64);
+ } else if (threads == 128) {
+ LAUNCH_SWIZZLE_QUANT(2, 128);
+ } else if (threads == 256) {
+ LAUNCH_SWIZZLE_QUANT(2, 256);
+ } else if (threads == 512) {
+ if (total_unroll == 2) {
+ LAUNCH_SWIZZLE_QUANT(2, 512);
+ } else if (total_unroll == 4) {
+ LAUNCH_SWIZZLE_QUANT(4, 512);
+ } else if (total_unroll == 6) {
+ LAUNCH_SWIZZLE_QUANT(6, 512);
+ } else if (total_unroll == 8) {
+ LAUNCH_SWIZZLE_QUANT(8, 512);
+ } else if (total_unroll == 10) {
+ LAUNCH_SWIZZLE_QUANT(10, 512);
+ }
+ }
+}
+
+#define DISPATCH_SWIZZLE_QUANT(num_bits, qtype) \
+ launch_swizzled_quant_impl(q_data, \
+ q_scales, \
+ input_data, \
+ groups, \
+ elems_per_group, \
+ pipelining, \
+ nodes, \
+ devices_per_node, \
+ stream);
+
+void launch_swizzled_quant(int8_t* q_data,
+ float* q_scales,
+ const __half* input_data,
+ int num_bits,
+ quantize::Type q_type,
+ int groups,
+ int elems_per_group,
+ int pipelining,
+ int nodes,
+ int devices_per_node,
+ cudaStream_t stream)
+{
+ if (num_bits == 4) {
+ if (q_type == quantize::Type::Asymmetric) {
+ DISPATCH_SWIZZLE_QUANT(4, quantize::Type::Asymmetric);
+ } else if (q_type == quantize::Type::Symmetric) {
+ DISPATCH_SWIZZLE_QUANT(4, quantize::Type::Symmetric);
+ }
+ } else if (num_bits == 8) {
+ if (q_type == quantize::Type::Asymmetric) {
+ DISPATCH_SWIZZLE_QUANT(8, quantize::Type::Asymmetric);
+ } else if (q_type == quantize::Type::Symmetric) {
+ DISPATCH_SWIZZLE_QUANT(8, quantize::Type::Symmetric);
+ }
+ }
+}
diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py
index b4a9a962b7824..dddcd18dc45da 100644
--- a/deepspeed/comm/comm.py
+++ b/deepspeed/comm/comm.py
@@ -557,6 +557,21 @@ def get_global_rank(group=None, group_rank=0):
return cdb.get_global_rank(group, group_rank)
+def get_all_ranks_from_group(group=None):
+ global cdb
+ assert cdb is not None and cdb.is_initialized(
+ ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ rank = 0
+ group_ranks = []
+ try:
+ while True:
+ group_ranks.append(cdb.get_global_rank(group, rank))
+ rank += 1
+ except RuntimeError:
+ pass
+ return group_ranks
+
+
# Main DeepSpeed Comms. public API.
def init_distributed(dist_backend=None,
auto_mpi_discovery=True,
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index 208416c533aa1..64ce31d648108 100644
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -8,8 +8,13 @@
from .utils import *
from .backend import *
from .comm import *
+import os
-import torch
+DS_COMM_ALL_GATHER_OFF = False
+DS_COMM_REDUCE_SCATTER_OFF = False
+DS_COMM_BROADCAST_OFF = False
+DS_COMM_ALL_REDUCE_OFF = False
+DS_COMM_REDUCE_OFF = False
def is_torch_two():
@@ -47,6 +52,46 @@ def get_coalescing_manager(group, device, reqs):
return torch.distributed.distributed_c10d._coalescing_manager(group, reqs)
+##Utilities to turn comm off
+##TODO: move to base comm (wrapper)
+def all_gather_comm_off(flag=False):
+ global DS_COMM_ALL_GATHER_OFF
+ DS_COMM_ALL_GATHER_OFF = flag
+
+
+def reduce_scatter_comm_off(flag=False):
+ global DS_COMM_REDUCE_SCATTER_OFF
+ DS_COMM_REDUCE_SCATTER_OFF = flag
+
+
+def broadcast_comm_off(flag=False):
+ global DS_COMM_BROADCAST_OFF
+ DS_COMM_BROADCAST_OFF = flag
+
+
+def all_reduce_comm_off(flag=False):
+ global DS_COMM_ALL_REDUCE_OFF
+ DS_COMM_ALL_REDUCE_OFF = flag
+
+
+def reduce_comm_off(flag=False):
+ global DS_COMM_REDUCE_OFF
+ DS_COMM_REDUCE_OFF = flag
+
+
+#assumption: all_gather and reduce scatter
+## are what we care about
+def backward_comm_off(flag=False):
+ all_gather_comm_off(flag)
+ reduce_scatter_comm_off(flag)
+
+
+class Noop:
+
+ def wait(self):
+ return None
+
+
class TorchBackend(Backend):
"""
A light-weight wrapper class for torch.distributed API.
@@ -116,20 +161,39 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+ if DS_COMM_REDUCE_OFF:
+ if int(os.getenv('RANK', '0')) == 0:
+ utils.logger.warning("REDUCE is OFF")
+ return Noop()
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)
def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
- return torch.distributed.reduce_scatter(output=output,
- input_list=input_list,
- op=self._reduce_op(op),
- group=group,
- async_op=async_op)
+ if DS_COMM_REDUCE_SCATTER_OFF:
+ if int(os.getenv('RANK', '0')) == 0:
+ utils.logger.warning("REDUCE SCATTER is OFF")
+ return Noop()
+ else:
+ return torch.distributed.reduce_scatter(output=output,
+ input_list=input_list,
+ op=self._reduce_op(op),
+ group=group,
+ async_op=async_op)
def broadcast(self, tensor, src, group=None, async_op=False):
- return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
+ if DS_COMM_BROADCAST_OFF:
+ if int(os.getenv('RANK', '0')) == 0:
+ utils.logger.warning("BROADCAST is OFF")
+ return Noop()
+ else:
+ return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
- return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
+ if DS_COMM_ALL_GATHER_OFF:
+ if int(os.getenv('RANK', '0')) == 0:
+ utils.logger.warning("All Gather is OFF")
+ return Noop()
+ else:
+ return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
@@ -137,11 +201,23 @@ def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_
input_tensor=input_tensor,
group=group,
async_op=async_op)
+
+ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
+ if DS_COMM_ALL_GATHER_OFF:
+ if int(os.getenv('RANK', '0')) == 0:
+ utils.logger.warning("All Gather is OFF")
+ return Noop()
else:
- utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
- "torch.distributed.all_gather which will result in suboptimal performance. "
- "please consider upgrading your pytorch installation.")
- pass
+ if self.has_allgather_base:
+ return torch.distributed.distributed_c10d._all_gather_base(output_tensor=output_tensor,
+ input_tensor=input_tensor,
+ group=group,
+ async_op=async_op)
+ else:
+ utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
+ "torch.distributed.reduce_scatter which will result in suboptimal performance. "
+ "please consider upgrading your pytorch installation.")
+ pass
def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
""""""
diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py
index 36f0cb80781f5..0e23476064bd5 100644
--- a/deepspeed/runtime/comm/coalesced_collectives.py
+++ b/deepspeed/runtime/comm/coalesced_collectives.py
@@ -9,19 +9,62 @@
import math
from typing import List
-
import torch
from torch import Tensor
from deepspeed import comm as dist
# NOTE: Use torch.distributed's ProcessGroup class until we have our own.
-from torch.distributed import ProcessGroup
-import torch.nn.functional
-
+from torch.distributed import ProcessGroup, all_to_all_single
+from deepspeed.accelerator import get_accelerator
from deepspeed.utils import instrument_w_nvtx
+from deepspeed.ops import op_builder
def _torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group=None, async_op=False, prof=False):
- return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor, input_tensor, group=group, async_op=async_op)
+ return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor, input_tensor, group=group, async_op=False)
+
+
+quantizer_module = None
+
+
+@instrument_w_nvtx
+@torch.no_grad()
+def all_to_all_quant_reduce(tensors: List[Tensor], groups: {}) -> List[Tensor]:
+ global quantizer_module
+ if quantizer_module is None:
+ quantizer_module = op_builder.QuantizerBuilder().load()
+ local_world_size = get_accelerator().device_count()
+ global_world_size = dist.get_world_size()
+ num_nodes = global_world_size // local_world_size
+ this_rank = dist.get_rank()
+ intra_idx = int(this_rank / local_world_size)
+ inter_idx = this_rank % local_world_size
+ output_lst: List[Tensor] = [None] * len(tensors)
+ for idx, tensor in enumerate(tensors):
+ if tensor.dim() == 1:
+ intra_quant_group = global_world_size
+ output_lst[idx] = reduce_scatter_coalesced([tensor])[0]
+ continue
+ else:
+ intra_quant_group = max(tensor.shape[0], tensor.shape[1], global_world_size)
+
+ inter_quant_group = intra_quant_group // local_world_size
+ intra_quant_int4, intra_q_scales = quantizer_module.swizzle_quant(tensor, intra_quant_group, 4,
+ quantizer_module.Symmetric, 1, num_nodes,
+ local_world_size)
+ local_output = torch.empty_like(intra_quant_int4)
+ scale_output = torch.empty_like(intra_q_scales)
+ all_to_all_single(local_output, intra_quant_int4, group=groups[f'local_{intra_idx}'])
+ all_to_all_single(scale_output, intra_q_scales, group=groups[f'local_{intra_idx}'])
+ global_input_tensor, global_scales = quantizer_module.quantized_reduction(
+ local_output, scale_output, intra_quant_group, inter_quant_group, 4, quantizer_module.Symmetric)
+ global_output = torch.empty_like(global_input_tensor)
+ global_scale_output = torch.empty_like(global_scales)
+ all_to_all_single(global_output, global_input_tensor, group=groups[f'global_{inter_idx}'])
+ all_to_all_single(global_scale_output, global_scales, group=groups[f'global_{inter_idx}'])
+ final_output = quantizer_module.dequantize(global_output, global_scale_output, global_scale_output.numel(),
+ 4, quantizer_module.Symmetric)
+ output_lst[idx] = (sum(list(final_output.chunk(num_nodes))) / num_nodes).view(-1)
+ return output_lst
@instrument_w_nvtx
@@ -32,7 +75,6 @@ def reduce_scatter_coalesced(
) -> List[Tensor]:
"""simultaneously reduce-scatter a list of tensors - this can be done more
efficiently than individual reduce scatter calls
-
TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL
"""
this_rank = dist.get_rank(group)
@@ -87,5 +129,4 @@ def reduce_scatter_coalesced(
0, offset, partition_lst_for_each_tensor[tensor_idx][this_rank].numel())
offset += padded_partition_sz_for_each_tensor[tensor_idx]
-
return output_lst
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 080998e742d7e..f408518586426 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -203,6 +203,7 @@ def __init__(
self.training_data = training_data
self.collate_fn = collate_fn
self.mpu = mpu
+ self.all_to_all_group = None
self.data_parallel_group = None
self.global_steps = 0
self.global_samples = 0
@@ -810,6 +811,15 @@ def zero_allgather_partitions(self):
def zero_round_robin_gradients(self):
return self._config.zero_config.round_robin_gradients
+ def zero_hpz_partition_size(self):
+ return self._config.zero_config.zero_hpz_partition_size
+
+ def zero_quantized_weights(self):
+ return self._config.zero_config.zero_quantized_weights
+
+ def zero_quantized_gradients(self):
+ return self._config.zero_config.zero_quantized_gradients
+
def dump_state(self):
return self._config.dump_state
@@ -1074,6 +1084,10 @@ def _configure_distributed_model(self, model):
module.set_deepspeed_parallelism()
# Query the groups module to get information about various parallel groups
+ self.local_all_to_all_group = None
+ if self.zero_quantized_gradients():
+ log_dist("Using quantized gradients", ranks=[0])
+ self.local_all_to_all_group = groups._get_local_all_to_all_group()
self.data_parallel_group = groups._get_data_parallel_group()
self.dp_world_size = groups._get_data_parallel_world_size()
self.mp_world_size = groups._get_model_parallel_world_size()
@@ -1449,6 +1463,10 @@ def _configure_zero_optimizer(self, optimizer):
assert not self.has_moe_layers, "MoE not supported with Stage 3"
if isinstance(optimizer, DummyOptim):
log_dist("Creating ZeRO Offload", ranks=[0])
+ zpg = groups._get_zero_param_intra_parallel_group()
+ if self.zero_hpz_partition_size() > 1 and zpg is None:
+ self._set_zero_group_parallelism()
+ zpg = groups._get_zero_param_intra_parallel_group()
optimizer = DeepSpeedZeRoOffload(self.module,
timers=timers,
ds_config=self.config,
@@ -1459,7 +1477,9 @@ def _configure_zero_optimizer(self, optimizer):
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
offload_param_config=self.zero_offload_param(),
- mpu=self.mpu)
+ mpu=self.mpu,
+ zero_param_parallel_group=zpg,
+ zero_quantized_weights=self.zero_quantized_weights())
else:
log_dist(
f'Creating fp16 ZeRO stage {zero_stage} optimizer,'
@@ -1488,6 +1508,7 @@ def _configure_zero_optimizer(self, optimizer):
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
dp_process_group=self.data_parallel_group,
+ all2all_process_group=self.local_all_to_all_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
@@ -1498,7 +1519,9 @@ def _configure_zero_optimizer(self, optimizer):
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
aio_config=self.aio_config(),
- communication_data_type=self.communication_data_type)
+ communication_data_type=self.communication_data_type,
+ zero_hpz_partition_size=self.zero_hpz_partition_size(),
+ zero_quantized_weights=self.zero_quantized_weights())
else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py
index 30d4ea3d4698f..55f933e789830 100644
--- a/deepspeed/runtime/zero/config.py
+++ b/deepspeed/runtime/zero/config.py
@@ -35,6 +35,9 @@
"offload_optimizer": {...},
"ignore_unused_parameters": [true|false],
"round_robin_gradients": [true|false],
+ "zero_hpz_partition_size": 1,
+ "zero_quantized_weights": [true|false],
+ "zero_quantized_gradients": [true|false],
"memory_efficient_linear": [true|false]
}
}
@@ -248,6 +251,20 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
Performance benefit grows with gradient accumulation steps (more copying
between optimizer steps) or GPU count (increased parallelism).
"""
+ zero_hpz_partition_size: int = Field(1, ge=0)
+ """
+ Number of ranks in zero parameters partitioning secondary group
+ """
+ zero_quantized_weights: bool = False
+ """
+ Boolean indicating whether to quantized zero parameters (weights)
+ for efficient all_gather comm
+ """
+ zero_quantized_gradients: bool = False
+ """
+ Boolean indicating whether to use quantized zero gradients
+ for efficient all_2_all_reduce comm
+ """
mics_shard_size: int = Field(-1, new_param="mics_shard_size")
diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py
index b32e9e6fae1b7..6a22b8bc74658 100755
--- a/deepspeed/runtime/zero/mics.py
+++ b/deepspeed/runtime/zero/mics.py
@@ -401,9 +401,23 @@ def __init__(self,
self.dp_process_group = first_param.comm.param_shard_group
self.partition_count = first_param.comm.param_shard_size
- def initialize_ds_offload(self, module, timers, ds_config, overlap_comm, prefetch_bucket_size, max_reuse_distance,
- max_live_parameters, param_persistence_threshold, model_persistence_threshold,
- offload_param_config, mpu):
+ def initialize_ds_offload(
+ self,
+ module,
+ timers,
+ ds_config,
+ overlap_comm,
+ prefetch_bucket_size,
+ max_reuse_distance,
+ max_live_parameters,
+ param_persistence_threshold,
+ model_persistence_threshold,
+ offload_param_config,
+ mpu,
+ zpg=None,
+ zero_quantized_weights=False,
+ ):
+ assert not zero_quantized_weights and zpg is None, "MiCS is mutually exclusive with ZeRO++"
return MiCS_Offload(module, timers, ds_config, overlap_comm, prefetch_bucket_size, max_reuse_distance,
max_live_parameters, param_persistence_threshold, model_persistence_threshold,
offload_param_config, mpu)
diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
index f0ed5013a3ea7..5e838f32b3b3a 100644
--- a/deepspeed/runtime/zero/parameter_offload.py
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -211,16 +211,21 @@ def __init__(self,
param_persistence_threshold=100000,
model_persistence_threshold=sys.maxsize,
offload_param_config=None,
- mpu=None):
+ mpu=None,
+ zero_param_parallel_group=None,
+ zero_quantized_weights=False):
see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False)
self.module = module
+ self.timers = timers
self.dtype = list(module.parameters())[0].dtype
self.offload_device = None
self.offload_param_pin_memory = False
+ self.zero_param_parallel_group = zero_param_parallel_group
+ self.zero_quantized_weights = zero_quantized_weights
if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
self.offload_device = offload_param_config.device
@@ -276,6 +281,7 @@ def get_param_coordinator(self, training):
allgather_stream=self.__allgather_stream,
inflight_param_registry=self.__inflight_param_registry,
prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
+ timers=self.timers,
)
return self.param_coordinators[training]
@@ -300,7 +306,9 @@ def _convert_to_zero_parameters(self, ds_config, module, mpu):
config_dict_or_path=ds_config,
remote_device=self.offload_device,
pin_memory=self.offload_param_pin_memory,
- mpu=mpu)
+ mpu=mpu,
+ zero_param_parallel_group=self.zero_param_parallel_group,
+ zero_quantized_weights=self.zero_quantized_weights)
def destroy(self):
self._remove_module_hooks()
@@ -340,7 +348,7 @@ def mark_persistent_parameters(self, param_threshold, model_threshold):
persistent_params = []
total_persistent_parameters = 0
params_count = 0
- for _, param in self.module.named_parameters(recurse=True):
+ for name, param in self.module.named_parameters(recurse=True):
if param.ds_numel + total_persistent_parameters > model_threshold:
continue
@@ -480,7 +488,7 @@ def pre_sub_module_forward_function(self, sub_module):
param_coordinator.trace_prologue(sub_module)
if param_coordinator.is_record_trace():
param_coordinator.record_module(sub_module)
- param_coordinator.fetch_sub_module(sub_module)
+ param_coordinator.fetch_sub_module(sub_module, forward=True)
see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
@@ -490,7 +498,7 @@ def post_sub_module_forward_function(self, sub_module):
force=False)
param_coordinator = self.get_param_coordinator(training=sub_module.training)
- param_coordinator.release_sub_module(sub_module)
+ param_coordinator.release_sub_module(sub_module, backward=False)
see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
force=False)
@@ -502,7 +510,7 @@ def pre_sub_module_backward_function(self, sub_module):
param_coordinator.trace_prologue(sub_module)
if param_coordinator.is_record_trace():
param_coordinator.record_module(sub_module)
- param_coordinator.fetch_sub_module(sub_module)
+ param_coordinator.fetch_sub_module(sub_module, forward=False)
@torch.no_grad()
def post_sub_module_backward_function(self, sub_module):
@@ -511,7 +519,7 @@ def post_sub_module_backward_function(self, sub_module):
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
force=False)
- self.get_param_coordinator(training=True).release_sub_module(sub_module)
+ self.get_param_coordinator(training=True).release_sub_module(sub_module, backward=True)
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 6b12e973ebcfb..3d9d9bdfb9099 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -20,6 +20,7 @@
from .linear import zero3_linear_wrap
+from deepspeed.utils import groups
import deepspeed
from ..utils import get_only_unique_item, see_memory_usage
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
@@ -537,16 +538,21 @@ def shutdown_init_context():
class AllGatherHandle:
- def __init__(self, handle, param: Parameter) -> None:
+ def __init__(self, handle, param: Parameter, quantization=None) -> None:
if param.ds_status != ZeroParamStatus.INFLIGHT:
raise RuntimeError(f"expected param {param.ds_summary()} to be available")
- self.handle = handle
- self.param = param
+ self.__handle = handle
+ self.__param = param
+ self.__quantization = quantization
def wait(self) -> None:
- instrument_w_nvtx(self.handle.wait)()
- self.param.ds_status = ZeroParamStatus.AVAILABLE
+ instrument_w_nvtx(self.__handle.wait)()
+ if self.__quantization:
+ instrument_w_nvtx(self.__quantization.quant_handle.wait)()
+ self.__param.data = self.__quantization.backend.dequantize(
+ self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device)
+ self.__param.ds_status = ZeroParamStatus.AVAILABLE
class AllGatherCoalescedHandle:
@@ -557,14 +563,18 @@ def __init__(
params: List[Parameter],
partitions: List[Tensor],
world_size: int,
+ use_secondary_tensor=False,
+ forward=False,
+ quantization=None,
) -> None:
- # renaming the fields without double underscore to ease
- # the class inheritance
self.allgather_handle = allgather_handle
self.params = params
self.partitions = partitions
self.world_size = world_size
+ self.use_secondary_tensor = use_secondary_tensor
+ self.forward = forward
self.complete = False
+ self.quantization = quantization
for param in self.params:
if param.ds_status != ZeroParamStatus.INFLIGHT:
@@ -577,16 +587,29 @@ def wait(self) -> None:
instrument_w_nvtx(self.allgather_handle.wait)()
+ if self.quantization:
+ instrument_w_nvtx(self.quantization.quant_handle.wait)()
+ flat_tensor = self.quantization.backend.dequantize(
+ self.quantization.quantized_param, self.quantization.scale_buffer).to(self.params[0].device)
+
+ self.partitions: List[Parameter] = []
+ for i in range(self.quantization.world_size):
+ self.partitions.append(
+ flat_tensor.narrow(0, self.quantization.partition_sz * i, self.quantization.partition_sz))
+
# split the single tensor out into individual tensors
param_offset = 0
for param in self.params:
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
partitions: List[Tensor] = []
+ ds_tensor_numel = param.ds_tensor.ds_numel
+ if self.use_secondary_tensor and not self.forward:
+ ds_tensor_numel *= param.ds_secondary_tensor_num_of_groups
for rank in range(self.world_size):
- param_start = rank * param.ds_tensor.ds_numel
+ param_start = rank * ds_tensor_numel
if param_start < param.ds_numel:
- part_to_copy = self.partitions[rank].narrow(
- 0, param_offset, min(param.ds_numel - param_start, param.ds_tensor.ds_numel))
+ part_to_copy = self.partitions[rank].narrow(0, param_offset,
+ min(param.ds_numel - param_start, ds_tensor_numel))
partitions.append(part_to_copy)
param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
param.ds_status = ZeroParamStatus.AVAILABLE
@@ -594,11 +617,59 @@ def wait(self) -> None:
for part_to_copy in partitions:
part_to_copy.record_stream(get_accelerator().current_stream())
- param_offset += param.ds_tensor.ds_numel
+ param_offset += ds_tensor_numel
self.complete = True
+class QuantizationInfo:
+ # a placeholder object to store all quant related vars used in handles
+ def __init__(self) -> None:
+ self.quantized_param = None
+ self.backend = None
+ self.quant_handle = None
+ self.scale_buffer = None
+
+
+class CUDAQuantizer:
+ async_flag = True
+ target_group_size = 8000 # the optimal size is 4k, so we set the target to be below 8k
+ group_size_cache = dict()
+
+ def __init__(self):
+ self.quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load()
+
+ def quantize(self, param, groups=None):
+ if groups is None:
+ try:
+ groups = self.group_size_cache[param.numel()]
+ except KeyError:
+ groups = math.ceil(param.numel() / self.target_group_size)
+ while groups < param.numel():
+ if param.numel() % (8 * groups) == 0:
+ break
+ groups += 1
+ while True:
+ if param.numel() % (8 * groups * 2) == 0 and param.numel(
+ ) / groups > self.target_group_size: #hard limit of 16k group_size
+ groups *= 2
+ else:
+ break
+ assert (
+ param.numel() % (8 * groups) == 0
+ ), f"Qantized weight requires the number of weights be a multiple of 8. Yet {param.numel()} cannot be divided by 8*{groups}"
+ assert (param.numel() / groups < 16000), f"{param.numel()} / {groups} is larger than 16k"
+ assert param.numel(
+ ) > groups, f"Adaptive grouping algorithm cannot find a group size for input tensor of size {param.numel()}"
+ self.group_size_cache[param.numel()] = groups
+ return self.quantizer_cuda_module.quantize(param.to(get_accelerator().device_name()), groups, 8,
+ self.quantizer_cuda_module.Symmetric)
+
+ def dequantize(self, quantized_param, scale):
+ return self.quantizer_cuda_module.dequantize(quantized_param, scale, scale.numel(), 8,
+ self.quantizer_cuda_module.Symmetric)
+
+
def _no_gather_coalesced(params: Iterable[Parameter]) -> AllGatherCoalescedHandle:
for param in params:
if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
@@ -631,7 +702,9 @@ def __init__(self,
config=None,
enabled=True,
dtype=None,
- mpu=None):
+ mpu=None,
+ zero_param_parallel_group=None,
+ zero_quantized_weights=False):
"""A context to enable massive model construction for training with
ZeRO-3. Models are automatically partitioned (or, sharded) across the
system and converted to half precision.
@@ -660,6 +733,8 @@ def __init__(self,
dtype (``dtype``, optional): Can be used to change the data type of the parameters.
Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.
+ zero_param_parallel_group(``object``, optional): Parallel (comm) group for dual partitioning of ZeRO params.
+ zero_quantized_weights (bool, optional): If ``True``, turn on quantized weights in all gather weights. Default is ``False``
This context accelerates model initialization and enables models that
are too large to allocate in their entirety in CPU memory. It has the
@@ -750,17 +825,42 @@ def get_model():
self.rank = dist.get_rank(group=self.ds_process_group)
self.dp_world_size = dist.get_world_size(group=self.ds_process_group)
+ self.zero_param_process_group = zero_param_parallel_group
+ if _ds_config is not None and _ds_config.zero_config.zero_hpz_partition_size > 1 and self.zero_param_process_group is None:
+ groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_partition_size)
+ self.zero_param_process_group = groups._get_zero_param_intra_parallel_group()
+
+ self.num_ranks_in_param_group = self.dp_world_size
+ self.rank_in_group = self.rank
+ self.num_param_groups = 1
+
+ if self.zero_param_process_group is not None:
+ self.num_ranks_in_param_group = groups._get_zero_param_intra_parallel_group_world_size()
+ self.num_param_groups = int(self.dp_world_size / self.num_ranks_in_param_group)
+ self.rank_in_group = groups._get_zero_param_intra_parallel_rank_in_mygroup()
+ print_rank_0(f"hpZeRO group size? {self.num_ranks_in_param_group}", force=True)
+
+ logger.debug(
+ "hpZeRO partition parameter my rank in world {} my rank in group {} ranks in my param partition group: {} "
+ .format(self.rank, self.rank_in_group, groups._get_zero_param_intra_parallel_group_ranks()))
+
# Local device is the device where the parameters are consumed, must be default device.
# It is the device where parameters are fully instantiated using allgather
self.local_device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
get_accelerator().set_device(self.local_device)
- if _ds_config is not None:
- self._update_persist_config(_ds_config)
+ self.quantized_weights = zero_quantized_weights
+ if _ds_config is not None and _ds_config.zero_config.zero_quantized_weights and not self.quantized_weights:
+ self.quantized_weights = _ds_config.zero_config.zero_quantized_weights
+
+ self.module = module
+ if (self.quantized_weights):
+ self.quantizer_module = CUDAQuantizer()
+ print_rank_0(f'Using quantizer: {self.quantizer_module.__class__.__name__}', force=True)
- if _ds_config.zero_config.offload_param is not None:
- remote_device = _ds_config.zero_config.offload_param.device
- pin_memory = _ds_config.zero_config.offload_param.pin_memory
+ if _ds_config is not None and _ds_config.zero_config.offload_param is not None:
+ remote_device = _ds_config.zero_config.offload_param.device
+ pin_memory = _ds_config.zero_config.offload_param.pin_memory
self._validate_remote_device(remote_device, _ds_config)
@@ -875,6 +975,14 @@ def _convert_to_deepspeed_param(self, param):
# The group that the parameter is scattered across.
param.ds_process_group = self.ds_process_group
+ # Stores the secondary partitioned copy of the tensor
+ param.ds_secondary_tensor = None
+
+ #Process group for secondary partition all (group) gather
+ param.ds_zero_param_process_group = self.zero_param_process_group
+ param.ds_secondary_tensor_group_size = self.num_ranks_in_param_group
+ param.ds_secondary_tensor_num_of_groups = self.num_param_groups
+
# This is set to the Async Param swapper if remote device is nvme
# else this is set to None
param.nvme_swapper = self.param_swapper
@@ -890,11 +998,17 @@ def all_gather(param_list=None, async_op=False, hierarchy=0):
return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)
@instrument_w_nvtx
- def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) -> AllGatherCoalescedHandle:
+ def all_gather_coalesced(params: Iterable[Parameter],
+ forward: bool,
+ safe_mode: bool = False) -> AllGatherCoalescedHandle:
# fetches from nvme if the partition is not available and in nvme
self._ensure_availability_of_partitioned_params(params)
+ quant = self.quantized_weights
+ if self.module is not None and self.module.training is False:
+ quant = False
+
if self.num_partitions == 1:
return _no_gather_coalesced(params)
@@ -903,6 +1017,17 @@ def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) -
raise RuntimeError(param.ds_summary())
param.ds_status = ZeroParamStatus.INFLIGHT
+ #use appropriate all gather process group
+ ds_process_group = self.ds_process_group
+ rank_in_group = self.rank
+ world_size = self.dp_world_size
+ use_secondary_tensor = False
+ if self.zero_param_process_group and not forward:
+ ds_process_group = self.zero_param_process_group #intragroup
+ rank_in_group = self.rank_in_group
+ world_size = self.num_ranks_in_param_group
+
+ #pprint(dir(ds_process_group))
# ensure that each rank has params in same order. the allgather
# is done by flattening the parameter list into a single tensor that
# can be allgathered in a single call - this means that if each rank
@@ -926,41 +1051,122 @@ def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) -
if len(params) == 1:
# have an opportunity to avoid some intermediate memory allocations
param, = params
+ buffer_size = math.ceil(param.ds_numel / world_size) * world_size
+ if not forward and param.ds_secondary_tensor is not None:
+ buffer_size = param.ds_secondary_tensor.shape[0] * world_size #make sure out is appropriately sized
+
param_buffer = torch.empty(
- math.ceil(param.ds_numel / self.num_partitions) * self.num_partitions,
- dtype=param.dtype,
- device=get_accelerator().current_device_name(),
+ buffer_size,
+ dtype=param.dtype if not quant else torch.int8,
+ device=get_accelerator().current_device(),
requires_grad=False,
)
- handle = _dist_allgather_fn(param.ds_tensor.to(get_accelerator().current_device_name()), param_buffer,
- self.get_partition_dp_group(param))
- param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
- return AllGatherHandle(handle, param)
+ param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor
+ if not quant:
+ handles = _dist_allgather_fn(
+ param_ds_tensor.to(get_accelerator().current_device()),
+ param_buffer,
+ ds_process_group,
+ )
+ param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
+ return AllGatherHandle(handles, param)
+ else:
+ quantized_param, scales = self.quantizer_module.quantize(param_ds_tensor)
+ handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device()), param_buffer,
+ ds_process_group)
+
+ quant_scale_buffer = torch.empty(
+ scales.numel() * world_size,
+ dtype=torch.float32,
+ device=get_accelerator().current_device(),
+ requires_grad=False,
+ )
+ quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device()),
+ quant_scale_buffer, ds_process_group)
+ quant_info = QuantizationInfo()
+
+ quant_info.quantized_param = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(
+ param.device)
+ quant_info.backend = self.quantizer_module
+ quant_info.quant_handle = quant_handle
+ quant_info.scale_buffer = quant_scale_buffer
+ return AllGatherHandle(handle, param, quantization=quant_info)
+
else:
partition_sz = sum(p.ds_tensor.ds_numel for p in params)
- flat_tensor = torch.empty(partition_sz * self.num_partitions,
- dtype=get_only_unique_item(p.dtype for p in params),
- device=get_accelerator().current_device_name(),
- requires_grad=False)
- partitions: List[Parameter] = []
- for i in range(self.num_partitions):
- partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))
-
- instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
- out=partitions[self.get_partition_rank()])
- handle = _dist_allgather_fn(partitions[self.get_partition_rank()], flat_tensor,
- self.get_partition_dp_group(params[0]))
-
- return AllGatherCoalescedHandle(
- allgather_handle=handle,
- params=params,
- partitions=partitions,
- world_size=self.num_partitions,
- )
- def partition(param_list=None, hierarchy=0, has_been_updated=False):
+ if params[0].ds_secondary_tensor is not None and not forward:
+ partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
+
+ flat_tensor = torch.empty(partition_sz * world_size,
+ dtype=get_only_unique_item(p.dtype
+ for p in params) if not quant else torch.int8,
+ device=get_accelerator().current_device(),
+ requires_grad=False)
+ if not quant:
+ partitions: List[Parameter] = []
+ for i in range(world_size):
+ partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))
+
+ if params[0].ds_secondary_tensor is not None and not forward:
+ use_secondary_tensor = True
+ instrument_w_nvtx(torch.cat)(
+ [p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params],
+ out=partitions[rank_in_group])
+ else:
+ instrument_w_nvtx(
+ torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
+ out=partitions[rank_in_group])
+ handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
+ #Fix get_partition_dp_group(params[0]))
+
+ return AllGatherCoalescedHandle(
+ allgather_handle=handle,
+ params=params,
+ partitions=partitions,
+ world_size=world_size,
+ use_secondary_tensor=use_secondary_tensor,
+ forward=forward,
+ )
+ else:
+ if params[0].ds_secondary_tensor is not None and not forward:
+ use_secondary_tensor = True
+ quantized_param, scales = self.quantizer_module.quantize(
+ instrument_w_nvtx(torch.cat)(
+ [p.ds_secondary_tensor.to(get_accelerator().current_device()) for p in params]))
+ else:
+ quantized_param, scales = self.quantizer_module.quantize(
+ instrument_w_nvtx(
+ torch.cat)([p.ds_tensor.to(get_accelerator().current_device()) for p in params]))
+ handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group)
+ quant_info = QuantizationInfo()
+ quant_scale_buffer = torch.empty(
+ scales.numel() * world_size,
+ dtype=torch.float32,
+ device=get_accelerator().current_device(),
+ requires_grad=False,
+ )
+ quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)
+ quant_info.quantized_param = flat_tensor
+ quant_info.backend = self.quantizer_module
+ quant_info.quant_handle = quant_handle
+ quant_info.scale_buffer = quant_scale_buffer
+ quant_info.partition_sz = partition_sz
+ quant_info.world_size = world_size
+ return AllGatherCoalescedHandle(
+ allgather_handle=handle,
+ params=params,
+ partitions=None,
+ world_size=world_size,
+ use_secondary_tensor=use_secondary_tensor,
+ forward=forward,
+ quantization=quant_info,
+ )
+
+ def partition(param_list=None, backward=False, hierarchy=0, has_been_updated=False):
cls = param
- print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}")
+ print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
+ force=False)
if param_list is None:
param_list = [cls]
self._partition(param_list, has_been_updated=has_been_updated)
@@ -1099,22 +1305,21 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
def _partition(self, param_list, force=False, has_been_updated=False):
for param in param_list:
- #print_rank_0(f"Before Partitioning Param {param.ds_id}")
- # self._param_status(param)
+ print_rank_0(f"Before Partitioning Param {param.ds_id} pri: {param.ds_tensor}", force=False)
+ if self.zero_param_process_group is not None:
+ self._partition_param_sec(param, has_been_updated=has_been_updated)
self._partition_param(param, has_been_updated=has_been_updated)
+
param.ds_status = ZeroParamStatus.NOT_AVAILABLE
# if param.ds_tensor is not None:
# assert id(param.data) == id(param.ds_tensor.data), \
# "After the parameters are initially partitioned, make sure we are not recreating the partition."
- #print_rank_0(f"After Partitioning Param {param.ds_id}")
- # self._param_status(param)
-
+ #print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False)
@instrument_w_nvtx
def _partition_param(self, param, buffer=None, has_been_updated=False):
assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
-
global reuse_buffers
- #print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}")
+ print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
if param.ds_status is ZeroParamStatus.AVAILABLE:
print_rank_0(f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}", force=False)
# if reuse_buffers and False:
@@ -1128,8 +1333,10 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
# if deepspeed.comm.get_rank():
# print(f"Releasing {param.data.numel()}")
- if param.ds_tensor is not None and not has_been_updated:
+ if param.ds_tensor is not None and not has_been_updated: ##param already partitioned
+
+ #print_rank_0(f"Param {param.ds_id} pri {param.ds_tensor.size()} loc? {param.ds_tensor.final_location}", force=True)
#param.data = param.ds_tensor.data
see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
@@ -1140,6 +1347,9 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
print_rank_0(f"Param {param.ds_id} partition released since it exists in nvme", force=False)
param.nvme_swapper.remove_partition_and_release_buffers([param])
+ print_rank_0(
+ f"after swap Param {param.ds_id} {param.ds_tensor.shape} partition released since it exists in nvme",
+ force=False)
return
@@ -1184,6 +1394,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
src_tensor = one_dim_param.narrow(0, start, partition_size)
param.ds_tensor.copy_(src_tensor)
+
#partitioned_tensor = src_tensor.clone().detach().to(self.remote_device)
else:
@@ -1215,6 +1426,59 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
print_rank_0(f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}")
+ @instrument_w_nvtx
+ def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
+ assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
+ global reuse_buffers
+ ##support for NVME secondary param offload
+ #print_rank_0(f"SEC Param id {param.ds_id} status is {param.ds_status}", force=True)
+ if param.ds_status is ZeroParamStatus.AVAILABLE:
+ if param.ds_secondary_tensor is not None and not has_been_updated: ##param already partitioned
+
+ return
+ #check padding
+ tensor_size = self._aligned_size(param)
+ partition_size = tensor_size // self.dp_world_size
+
+ secondary_partition_size = int(tensor_size // self.num_ranks_in_param_group)
+ if param.ds_secondary_tensor is None:
+ final_location = None
+ secondary_partitioned_tensor = torch.empty(secondary_partition_size,
+ dtype=param.dtype,
+ device=self.remote_device)
+
+ if self.pin_memory:
+ secondary_partitioned_tensor = secondary_partitioned_tensor.pin_memory()
+
+ secondary_partitioned_tensor.requires_grad = False
+ param.ds_secondary_tensor = secondary_partitioned_tensor
+ param.ds_secondary_tensor.ds_numel = secondary_partition_size
+ param.ds_secondary_tensor.status = PartitionedParamStatus.AVAILABLE
+ param.ds_secondary_tensor.final_location = final_location
+
+ #use rank in group for secondary tensor
+ secondary_start = secondary_partition_size * self.rank_in_group
+
+ secondary_end = secondary_start + secondary_partition_size
+
+ one_dim_param = param.contiguous().view(-1)
+ start = partition_size * self.rank
+ end = start + partition_size
+ if start < param.ds_numel and end <= param.ds_numel:
+ if secondary_start < param.ds_numel and secondary_end <= param.ds_numel:
+ sec_src_tensor = one_dim_param.narrow(0, secondary_start, secondary_partition_size)
+ param.ds_secondary_tensor.copy_(sec_src_tensor)
+
+ else:
+ if start < param.ds_numel:
+ elements_to_copy = param.ds_numel - start
+ elements_to_copy_sec = elements_to_copy * param.ds_secondary_tensor_num_of_groups
+ param.ds_secondary_tensor.narrow(0, 0, elements_to_copy_sec).copy_(
+ one_dim_param.narrow(0, secondary_start, elements_to_copy_sec))
+
+ print_rank_0(f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}",
+ force=False)
+
def _param_status(self, param):
if param.ds_tensor is not None:
print_rank_0(
diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index 8bf999458d8e1..5f2fdfeff8115 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -12,6 +12,7 @@
from deepspeed.utils.logging import logger
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import *
+from deepspeed.runtime.zero.partitioned_param_profiler import PartitionedParameterProfiler
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
from deepspeed.accelerator import get_accelerator
@@ -53,6 +54,14 @@ def __setitem__(self, param: Parameter, handle: AllGatherCoalescedHandle) -> Non
class PartitionedParameterCoordinator:
+ FORWARD_FETCH_SUBMIT = 'forward_fetch_submit'
+ FORWARD_FETCH_WAIT = 'forward_fetch_wait'
+ FORWARD_PREFETCH_SUBMIT = 'forward_prefetch_submit'
+ BACKWARD_FETCH_SUBMIT = 'backward_fetch_submit'
+ BACKWARD_FETCH_WAIT = 'backward_fetch_wait'
+ BACKWARD_PREFETCH_SUBMIT = 'backward_prefetch_wait'
+ FORWARD_ALL_GATHER = 'forward_all_gather'
+ BACKWARD_ALL_GATHER = 'backward_all_gather'
"""Handles partitioning and gathering of parameters."""
@dataclass
@@ -68,6 +77,7 @@ def __init__(
allgather_stream: get_accelerator().Stream,
inflight_param_registry: InflightParamRegistry,
prefetch_nvme: bool = False,
+ timers=None,
) -> None:
# mapping of param -> handle for each param that is currently in flight
self.__inflight_param_registry = inflight_param_registry
@@ -107,6 +117,7 @@ def __init__(
self.__ongoing_fetch_events: Deque[get_accelerator().Event] = collections.deque()
# TODO. make this configurable via JSON
self.__max_ongoing_fetch_events: int = 2
+ self.__profiler = PartitionedParameterProfiler(timers)
"""Tracing and Tracking
TODO. consider performing trace before initializing PartitionedParameterCoordinator
@@ -207,22 +218,28 @@ def reset_step(self) -> None:
# Enable trace recording for next forward/backward pass
self.__trace_mode = ZeRoTraceMode.RECORD
+ else:
+ if self.__profiler is not None:
+ self.__profiler.log_events()
+
self.__param_queue = collections.deque(self.__param_order) # reset fetch queue
self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10))
self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque())
self.__step_id = 0
self.__n_available_params = 0
+ self.__profiler.reset_events()
def _dump_params(self, tag, sub_module, params, step_id=None):
if step_id is None:
step_id = self.__step_id
param_names = [debug_param2name_id(p) for p in params]
- print(f'{tag} step = {step_id} mod = {debug_module2name_id(sub_module)} p_names = {param_names}')
+ print_rank_0(f'{tag} step = {step_id} mod = {debug_module2name_id(sub_module)} p_names = {param_names}',
+ force=False)
def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None):
if step_id is None:
step_id = self.__step_id
- print(f'{tag} mod = {mod_id}, step = {step_id}, p_ids = {p_ids}')
+ print_rank_0(f'{tag} mod = {mod_id}, step = {step_id}, p_ids = {p_ids}', force=False)
"""Fetch and Release
Fetching, prefetching, and releasing parameters
@@ -230,7 +247,7 @@ def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None):
@instrument_w_nvtx
@torch.no_grad()
- def fetch_sub_module(self, current_submodule: Module) -> None:
+ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
"""This method does the following (in order):
1. kick off fetch for parameters in immediately required sub module
2. kick off fetch for next few parameters we will need later (prefetch)
@@ -246,19 +263,31 @@ def fetch_sub_module(self, current_submodule: Module) -> None:
}))
params_to_fetch = frozenset(iter_params(current_submodule))
-
- # kick off all gather for params in the immediately required submodule
- if logger.isEnabledFor(logging.DEBUG):
- for param in params_to_fetch:
- debug_rank0(f"-fetch: {param.ds_summary()}")
- self.__all_gather_params(params_to_fetch)
-
+ fetch_numel = sum(
+ [p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
+ if fetch_numel > 0:
+ event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT
+ self._dump_param_ids(event_name, current_submodule.id,
+ [p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
+ self.__profiler.start_event(event_name)
+ # kick off all gather for params in the immediately required submodule
+ #for param in params_to_fetch:
+ if logger.isEnabledFor(logging.DEBUG):
+ for param in params_to_fetch:
+ debug_rank0(f"-fetch: {param.ds_summary()}")
+ self.__all_gather_params(params_to_fetch, forward)
+ self.__profiler.stop_event(event_name, fetch_numel)
+
+ wait_numel = 0
+ wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT
+ self.__profiler.start_event(wait_event_name)
# wait for parameters in the immediately needed submodule to become available
for param in params_to_fetch:
param.ds_active_sub_modules.add(current_submodule.id)
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-wait: {param.ds_summary()}")
if param in self.__inflight_param_registry:
+ wait_numel += param.partition_numel()
with get_accelerator().stream(self.__allgather_stream):
while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query():
self.__ongoing_fetch_events.popleft()
@@ -273,6 +302,7 @@ def fetch_sub_module(self, current_submodule: Module) -> None:
assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
get_accelerator().current_stream().wait_stream(self.__allgather_stream)
+ self.__profiler.stop_event(wait_event_name, wait_numel)
# kick off parameter prefetches for upcoming modules
# don't prefetch if we dont have a completed model trace
@@ -332,10 +362,14 @@ def _is_currently_on_nvme(param):
params_to_prefetch.add(param_in_trace.param)
numel_prefetching += param_in_trace.param.ds_numel
- if logger.isEnabledFor(logging.DEBUG):
- for param in params_to_prefetch:
- debug_rank0(f"-prefetch: {param.ds_summary()}")
- self.__all_gather_params(params_to_prefetch)
+ if numel_prefetching > 0:
+ event_name = __class__.FORWARD_PREFETCH_SUBMIT if forward else __class__.BACKWARD_PREFETCH_SUBMIT
+ self.__profiler.start_event(event_name)
+ if logger.isEnabledFor(logging.DEBUG):
+ for param in params_to_prefetch:
+ debug_rank0(f"-prefetch: {param.ds_summary()}")
+ self.__all_gather_params(params_to_prefetch, forward)
+ self.__profiler.stop_event(event_name, numel_prefetching)
if self.__prefetch_nvme:
self.__prefetch_nvme_param_partitions()
@@ -344,7 +378,7 @@ def _is_currently_on_nvme(param):
@instrument_w_nvtx
@torch.no_grad()
- def release_sub_module(self, submodule: Module) -> None:
+ def release_sub_module(self, submodule: Module, backward: bool) -> None:
"""release the parameters of a sub module, assuming they meet conditions to
be released."""
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
@@ -352,7 +386,7 @@ def release_sub_module(self, submodule: Module) -> None:
for param in iter_params(submodule):
param.ds_active_sub_modules.discard(submodule.id)
if param.ds_id in params_to_release and not param.is_external_param:
- self.__release_param(param)
+ self.__release_param(param, backward)
@instrument_w_nvtx
@torch.no_grad()
@@ -365,25 +399,30 @@ def release_and_reset_all(self, module: Module) -> None:
# TODO. make this throw if if there are still active submodules. currently
# there's a hook execution issue
param.ds_active_sub_modules.clear()
- self.__release_param(param)
+ self.__release_param(param, backward=False)
for param in iter_params(module, recurse=True):
if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
raise RuntimeError(f"{param.ds_summary()} expected to be released")
@instrument_w_nvtx
- def __all_gather_params(self, params: Set[Parameter]) -> None:
+ def __all_gather_params(self, params: Set[Parameter], forward: bool) -> None:
"""for each partitioned parameter, kick off an async allgather and store
the work handle for the in flight parameters."""
partitioned_params = []
+ all_gather_numel = 0
for param in params:
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
partitioned_params.append(param)
- self.__n_available_params += param.ds_numel
+ all_gather_numel += param.ds_numel
if partitioned_params:
+ self.__n_available_params += all_gather_numel
with get_accelerator().stream(self.__allgather_stream):
- handle = partitioned_params[0].all_gather_coalesced(partitioned_params)
+ event_name = __class__.FORWARD_ALL_GATHER if forward else __class__.BACKWARD_ALL_GATHER
+ self.__profiler.start_event(event_name)
+ handle = partitioned_params[0].all_gather_coalesced(partitioned_params, forward)
+ self.__profiler.stop_event(event_name, all_gather_numel)
for param in partitioned_params:
assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary()
@@ -397,11 +436,11 @@ def __all_gather_params(self, params: Set[Parameter]) -> None:
swap_persisted_params[0].nvme_swapper.remove_partition_and_release_buffers(swap_persisted_params)
@instrument_w_nvtx
- def __release_param(self, param: Parameter) -> None:
+ def __release_param(self, param: Parameter, backward: bool) -> None:
if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-release: {param.ds_summary()}")
- param.partition()
+ param.partition(backward=backward)
self.__n_available_params -= param.ds_numel
@instrument_w_nvtx
diff --git a/deepspeed/runtime/zero/partitioned_param_profiler.py b/deepspeed/runtime/zero/partitioned_param_profiler.py
new file mode 100644
index 0000000000000..b4ea11f3b8363
--- /dev/null
+++ b/deepspeed/runtime/zero/partitioned_param_profiler.py
@@ -0,0 +1,63 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from dataclasses import dataclass
+from deepspeed.utils import log_dist
+
+
+class PartitionedParameterProfiler(object):
+
+ @dataclass
+ class EventCounter:
+ name: str
+ count: int
+ num_elem: int
+
+ def reset(self):
+ self.count = 0
+ self.num_elem = 0
+
+ def increment(self, numel):
+ self.count += 1
+ self.num_elem += numel
+
+ def __init__(self, timers):
+ self.timers = timers
+ self.event_counters = {}
+
+ def reset_events(self):
+ for event_ctr in self.event_counters.values():
+ event_ctr.reset()
+
+ def start_event(self, name):
+ if self.timers is None:
+ return
+
+ if name not in self.event_counters:
+ self.event_counters[name] = __class__.EventCounter(name=name, count=0, num_elem=0)
+ self.timers(name).start()
+
+ def stop_event(self, name, num_elem):
+ if self.timers is None:
+ return
+ assert name in self.event_counters, f'unknown event {name}'
+ self.event_counters[name].increment(num_elem)
+ self.timers(name).stop()
+
+ def _log_timers(self):
+ if self.timers is None:
+ return
+ self.timers.log(names=list(self.event_counters.keys()))
+
+ def _log_event_counters(self):
+ for event_ctr in self.event_counters.values():
+ log_dist(
+ f'{event_ctr.name}: count = {event_ctr.count}, numel = {event_ctr.num_elem}',
+ #f'{event_ctr.name}: time = {self._log_timers()},count = {event_ctr.count}, numel = {event_ctr.num_elem}',
+ ranks=[0])
+
+ def log_events(self):
+ self._log_event_counters()
+ self._log_timers()
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 098e7c22ee22f..20d6168b6b797 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -7,12 +7,14 @@
import gc
import collections
from typing import Deque, Dict, Tuple
+from deepspeed import comm as dist
+from deepspeed.utils import groups
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime import ZeROOptimizer
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
-from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced
+from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
@@ -103,8 +105,10 @@ def __init__(self,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1,
elastic_checkpoint=False,
- aio_config=None):
-
+ aio_config=None,
+ all2all_process_group=None,
+ zero_hpz_partition_size=1,
+ zero_quantized_weights=False):
see_memory_usage("Stage 3 initialize beginning", force=True)
print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False)
@@ -146,6 +150,15 @@ def __init__(self,
self.params_in_nvme_and_cpu = False
self.max_params_in_cpu = 0
+ #num of ranks in a ZeRO param partitioning group
+ self.zero_hpz_partition_size = zero_hpz_partition_size
+
+ zpg = groups._get_zero_param_intra_parallel_group()
+ print_rank_0(f"ZeRO Stage 3 param partitioning group {self.zero_hpz_partition_size} {zpg}", force=False)
+ if self.zero_hpz_partition_size > 1 and zpg is None:
+ self._set_zero_group_parallelism()
+ zpg = groups._get_zero_param_intra_parallel_group()
+
self.parameter_offload = self.initialize_ds_offload(module=module,
timers=timers,
ds_config=ds_config,
@@ -156,7 +169,9 @@ def __init__(self,
param_persistence_threshold=param_persistence_threshold,
model_persistence_threshold=model_persistence_threshold,
offload_param_config=offload_param_config,
- mpu=mpu)
+ mpu=mpu,
+ zpg=zpg,
+ zero_quantized_weights=zero_quantized_weights)
self.persistent_parameters = self.parameter_offload.persistent_parameters
self._configure_offloading(offload_optimizer_config, offload_param_config)
@@ -184,10 +199,14 @@ def __init__(self,
self.timers = timers
+ self.all2all_process_group = all2all_process_group
+
self.reduce_scatter = reduce_scatter
self.dp_process_group = dp_process_group
+ self.all2all_process_group = all2all_process_group
+
self.partition_count = dist.get_world_size(group=self.dp_process_group)
if mpu is None:
@@ -206,6 +225,9 @@ def __init__(self,
self.micro_step_id = 0
self.reduce_bucket_size = int(reduce_bucket_size)
+ if self.all2all_process_group is not None:
+ assert self.all2all_process_group is not None and self.reduce_scatter == True, "when enable all_to_all_reduce, reduce_scatter should also be enabled for data type checks."
+
if self.reduce_scatter:
valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
assert self.communication_data_type in valid_reduce_scatter_dtypes, f"ZeRO-3 supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
@@ -314,6 +336,7 @@ def __init__(self,
self.averaged_gradients = {}
#creates backward hooks for gradient partitioning
+ ###Calls all gather param
self.create_reduce_and_remove_grad_hooks()
#exit(0)
@@ -348,6 +371,8 @@ def initialize_ds_offload(
model_persistence_threshold,
offload_param_config,
mpu,
+ zpg,
+ zero_quantized_weights,
):
return DeepSpeedZeRoOffload(module=module,
timers=timers,
@@ -359,7 +384,9 @@ def initialize_ds_offload(
param_persistence_threshold=param_persistence_threshold,
model_persistence_threshold=model_persistence_threshold,
offload_param_config=offload_param_config,
- mpu=mpu)
+ mpu=mpu,
+ zero_param_parallel_group=zpg,
+ zero_quantized_weights=zero_quantized_weights)
def _get_trainable_parameter_groups(self):
param_groups = []
@@ -368,6 +395,15 @@ def _get_trainable_parameter_groups(self):
param_groups.append(trainable_params)
return param_groups
+ def _set_zero_group_parallelism(self):
+ groups._create_zero_param_parallel_group(self.zero_hpz_partition_size)
+
+ def invalidate_secondary_tensor(self):
+ for fpg in self.fp16_groups:
+ for param in fpg:
+ if param.ds_secondary_tensor is not None:
+ param.ds_secondary_tensor = None
+
def _setup_for_real_optimizer(self):
see_memory_usage("Before creating fp32 partitions", force=True)
self._create_fp32_partitions()
@@ -992,6 +1028,7 @@ def create_reduce_and_remove_grad_hooks(self):
for param in param_group:
if param.requires_grad:
#print_rank_0(f" Before all gather {param.device}, {param.shape}")
+ print_rank_0(f"Before all gather {param.device}, {param.shape}", force=False)
# The hook must be created in un-partitioned parameter
param.all_gather()
@@ -1148,7 +1185,13 @@ def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]
if self.postscale_gradients and self.gradient_predivide_factor != 1.0:
full_grads_for_rank = [g.div(self.gradient_predivide_factor) for g in full_grads_for_rank]
- grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, self.dp_process_group)
+ local_world_size = get_accelerator().device_count()
+ global_world_size = dist.get_world_size()
+ num_nodes = global_world_size // local_world_size
+ if self.all2all_process_group is not None and num_nodes > 1:
+ grad_partitions_for_rank = all_to_all_quant_reduce(full_grads_for_rank, self.all2all_process_group)
+ else:
+ grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, self.dp_process_group)
if self.postscale_gradients and self.gradient_predivide_factor != 1.0 and self.gradient_predivide_factor != dist.get_world_size(
self.dp_process_group):
@@ -1254,6 +1297,8 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
# offload the gradient partition if applicable
if self.offload_optimizer:
i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
+ offload_fp32_gradients = {}
+ offload_fp32_offsets = {}
if self.is_gradient_accumulation_boundary:
self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_buffer)
@@ -1773,6 +1818,8 @@ def _post_step(self, timer_names=set()):
if self.swap_optimizer:
self.optimizer_swapper.log_timers()
+ self.invalidate_secondary_tensor()
+
self.log_timers(timer_names)
see_memory_usage('After zero_optimizer step', force=False)
diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py
index dc58bb3c780f0..21c544c07a79b 100644
--- a/deepspeed/utils/groups.py
+++ b/deepspeed/utils/groups.py
@@ -26,20 +26,25 @@
"""
from deepspeed import comm as dist
-
from deepspeed.utils import log_dist
from deepspeed.utils.exceptions import DeprecatedException
-
+from deepspeed.accelerator import get_accelerator
# Expert parallel group that the current rank belongs to.
_EXPERT_PARALLEL_GROUP = {}
# Expert data parallel group that the current rank belongs to.
_EXPERT_DATA_PARALLEL_GROUP = {}
# dist world group needs to be cloned for some cases
_WORLD_GROUP = None
+# ZeRO parameter partitioning group that the current rank belongs to.
+_ZERO_PARAM_INTRA_PARALLEL_GROUP = None
# global object to maintain mpu object if passed by a Megatron client
mpu = None
# global object that stores tensor parallel world size for experts
expert_tensor_parallel_world_size = 1
+# All to All quantized graident communication groups
+_ALL_TO_ALL_GROUP = {}
+
+_DATA_PARALLEL_GROUP = None
# Deprecated groups initialize function.
@@ -316,10 +321,38 @@ def _clone_world_group():
return _WORLD_GROUP
+def _get_local_all_to_all_group():
+ assert dist.is_initialized(), 'dist is not initialized'
+ global _ALL_TO_ALL_GROUP
+ device_per_node = get_accelerator().device_count()
+ num_local = dist.get_world_size() // device_per_node
+ if num_local == 0 and dist.get_world_size() > 0:
+ assert dist.get_world_size() >= 1, 'num_gpus must >=1, cannot initialize All-To-All'
+ cur_rank = []
+ for i in range(dist.get_world_size()):
+ cur_rank.append(i)
+ _ALL_TO_ALL_GROUP['local_0'] = dist.new_group(ranks=cur_rank)
+ elif num_local == 1:
+ assert dist.get_world_size(
+ ) == device_per_node, 'num_gpus not equal to device per node, cannot initialize All-To-All'
+ _ALL_TO_ALL_GROUP['local_0'] = dist.new_group(ranks=[i for i in range(device_per_node)])
+ else:
+ assert dist.get_world_size() > device_per_node, 'num_nodes<2 cannot initialize All-To-All'
+ for i in range(num_local):
+ local_rank = [j + device_per_node * i for j in range(device_per_node)]
+ _ALL_TO_ALL_GROUP[f"local_{i}"] = dist.new_group(ranks=local_rank)
+
+ for i in range(device_per_node):
+ cur_rank = []
+ for j in range(num_local):
+ cur_rank.append(i + j * device_per_node)
+ _ALL_TO_ALL_GROUP[f"global_{i}"] = dist.new_group(ranks=cur_rank)
+ return _ALL_TO_ALL_GROUP
+
+
def _get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
- assert dist.is_initialized(), \
- 'dist is not initialized'
+ assert dist.is_initialized(), 'dist is not initialized'
global mpu
if mpu is not None:
return mpu.get_data_parallel_group()
@@ -390,3 +423,63 @@ def _get_data_parallel_rank():
def _get_expert_model_parallel_world_size():
global expert_tensor_parallel_world_size
return expert_tensor_parallel_world_size
+
+
+def _create_zero_param_parallel_group(group_size):
+ """
+ Create parameter partitioning group within ZeRO data parallel groups.
+
+ Example - ZP + D parallel
+ world_size = 16
+ zero_hpz_partition_size = 2 # number of ranks with with replicated params (dual partitioning)
+ zero_param_intra_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - segmented (subgroup) with rep partition
+ data_parallel_group = [0,1,...,15] - all reduce is on ZeRO model
+ """
+ assert dist.is_initialized()
+ global _ZERO_PARAM_INTRA_PARALLEL_GROUP
+ # Only create group if it does not already exist
+ assert _ZERO_PARAM_INTRA_PARALLEL_GROUP is None, \
+ 'ZeRO parameter intra parallel group is already initialized'
+
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+
+ zero_param_parallel_size_ = min(group_size, world_size)
+ _ensure_divisibility(world_size, zero_param_parallel_size_)
+
+ # Build the ZeRO param intra parallel groups.
+ for i in range(world_size // zero_param_parallel_size_):
+ ranks = range(i * zero_param_parallel_size_, (i + 1) * zero_param_parallel_size_)
+ group = dist.new_group(ranks)
+ if i == (rank // zero_param_parallel_size_):
+ _ZERO_PARAM_INTRA_PARALLEL_GROUP = group
+
+
+def _get_zero_param_intra_parallel_group():
+ """Get the ZeRO parameter partitioning intra parallel group the caller rank belongs to."""
+ #assert _ZERO_PARAM_INTRA_PARALLEL_GROUP is not None, \
+ # 'ZeRO parameter partitioning group is not initialized'
+ #TODO: Add warning
+ return _ZERO_PARAM_INTRA_PARALLEL_GROUP
+
+
+def _zero_param_parallel_is_initialized():
+ """Check if ZeRO data parallel with parameter partititioning groups are initialized."""
+ ###TODO: assert that MPU is not set
+ if _ZERO_PARAM_INTRA_PARALLEL_GROUP is None and _DATA_PARALLEL_GROUP is None:
+ return False
+
+
+def _get_zero_param_intra_parallel_rank_in_mygroup():
+ """Return my rank for the ZeRO parameter inter parallel group."""
+ return dist.get_rank(group=_get_zero_param_intra_parallel_group())
+
+
+def _get_zero_param_intra_parallel_group_world_size():
+ """Return world size for the ZeRO parameter parallel group."""
+ return dist.get_world_size(group=_get_zero_param_intra_parallel_group())
+
+
+def _get_zero_param_intra_parallel_group_ranks():
+ """Return all ranks for the ZeRO parameter intra parallel group."""
+ return dist.get_all_ranks_from_group(group=_get_zero_param_intra_parallel_group())
diff --git a/docs/_tutorials/zeropp.md b/docs/_tutorials/zeropp.md
index 8a4a825b72077..f266b4d3df064 100644
--- a/docs/_tutorials/zeropp.md
+++ b/docs/_tutorials/zeropp.md
@@ -1,9 +1,9 @@
---
title: "ZeRO++"
-tags: training, ZeRO, communication-efficiency, large-model
+tags: training ZeRO communication-efficiency large-model
---
-ZeRO++ is a system of communication optimization strategies built on top of [ZeRO](https://www.microsoft.com/en-us/research/blog/zeropp) to offer unmatched efficiency for large model training regardless of the scale or cross-device bandwidth constraints. Read our [ZeRO++ blog](https://www.microsoft.com/en-us/research/blog/msr-zeropp-placeholder/) and [paper](https://www.microsoft.com/en-us/research/blog/arxiv-placehoder/) to learn more!
+ZeRO++ is a system of communication optimization strategies built on top of [ZeRO](/tutorials/zero/) to offer unmatched efficiency for large model training regardless of the scale or cross-device bandwidth constraints. Read our [ZeRO++ blog](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/) and [paper](https://arxiv.org/pdf/2306.10209.pdf) to learn more!
We recommend that you read the tutorials on [Getting Started](/getting-started/), [ZeRO](/tutorials/zero/) and [Megatron-DeepSpeed](/tutorials/megatron/) before stepping through this tutorial.
@@ -12,7 +12,7 @@ We recommend that you read the tutorials on [Getting Started](/getting-started/)
ZeRO++ consists of three key designs, namely quantized weights (*qwZ*), hiearchical partitioning ZeRO (*hpZ*), and quantized gradients (*qgZ*):
- *qwZ* applies block-based quantization to reduce ZeRO parameter all-gather communication volume by half from FP16 to INT8)
- *hpZ* eliminates inter-node backward parameter all-gather communication through data remapping and recomputation
- - *qwG* replaces gradients allreduce collective with a new communication efficient all-to-all based quantized gradient averaging.
+ - *qgZ* replaces gradients allreduce collective with a new communication efficient all-to-all based quantized gradient averaging.
Collectively, the three optimization reduces communication volume by 4x compared to ZeRO baseline. Each of the three components can be enabled independent of each other and collectively as a group as described in the next section.
@@ -26,7 +26,7 @@ There are no change needed to the user code. However, since ZeRO++ extends ZeRO
- zero_quantized_weights: Boolean indicating whether to use quantized zero weights (*qwZ*), default is false
- zero_hpz_partition_size: number of ranks in *hpZ* (secondary partition) group, default is 1 meaning no hpZ, ideal is number of ranks (gpus) per node
- - zero_quantized_gradients: Boolean indicating whether to use quantized zero gradients (*qwG*), default is false
+ - zero_quantized_gradients: Boolean indicating whether to use quantized zero gradients (*qgZ*), default is false
### DeepSpeed Configuration Changes
@@ -72,13 +72,13 @@ See more details on Megatron-DeepSpeed [tutorial](/tutorials/megatron/) examples
Here is a screenshots of the training log for both ZeRO baseline and ZeRO++:
ZeRO baseline
-
-
+
+
ZeRO++
-
-
+
+
Congratulations! You have completed the ZeRO++ tutorial.
diff --git a/docs/assets/images/zeropp/ZeRO-baseline.png b/docs/assets/images/zeropp/ZeRO-baseline.png
new file mode 100644
index 0000000000000..108c06a097aac
Binary files /dev/null and b/docs/assets/images/zeropp/ZeRO-baseline.png differ
diff --git a/docs/assets/images/zeropp/ZeROpp.png b/docs/assets/images/zeropp/ZeROpp.png
new file mode 100644
index 0000000000000..a72715cb699a3
Binary files /dev/null and b/docs/assets/images/zeropp/ZeROpp.png differ
diff --git a/docs/index.md b/docs/index.md
index b1835b47fb86a..4284f0d9e6670 100755
--- a/docs/index.md
+++ b/docs/index.md
@@ -7,14 +7,11 @@ title: "Latest News"
---
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat).
+* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)
* [2023/04] π [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[δΈζ](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[ζ₯ζ¬θͺ](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]π
* [2023/03] [Scaling Large-Scale Generative Mixture-of-Expert Multimodal Model With VL-MoE](https://www.deepspeed.ai/2023/03/30/multi-modal.html)
* [2023/02] [Automatic Tensor Parallelism: Enables tensor parallelism by default without an injection policy](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/)
* [2022/12] [DeepSpeed Data Efficiency: A composable library that makes better use of data, increases training efficiency, and improves model quality](https://www.deepspeed.ai/2022/12/11/data-efficiency.html)
-* [2022/11] [Stable Diffusion Image Generation under 1 second w. DeepSpeed MII](https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/benchmark/txt2img)
-* [2022/10] [DeepSpeed-MII: instant speedup on 24,000+ open-source DL models with up to 40x cheaper inference](https://www.deepspeed.ai/2022/10/10/mii.html)
-* [2022/09] [ZeRO-Inference: Democratizing massive model inference](https://www.deepspeed.ai/2022/09/09/zero-inference.html)
-* [2022/07] [Azure and DeepSpeed empower easy-to-use and high-performance model training](https://azure.microsoft.com/en-us/blog/azure-empowers-easytouse-highperformance-and-hyperscale-model-training-using-deepspeed/)
# Extreme Speed and Scale for DL Training and Inference
diff --git a/op_builder/quantizer.py b/op_builder/quantizer.py
index a64d1603d1e5e..5f651365187b3 100644
--- a/op_builder/quantizer.py
+++ b/op_builder/quantizer.py
@@ -23,6 +23,8 @@ def sources(self):
'csrc/quantization/fake_quantizer.cu',
'csrc/quantization/quantize.cu',
'csrc/quantization/dequantize.cu',
+ 'csrc/quantization/swizzled_quantize.cu',
+ 'csrc/quantization/quant_reduce.cu',
]
def include_paths(self):
diff --git a/tests/small_model_debugging/test_model.py b/tests/small_model_debugging/test_model.py
index 586106140d0b9..2706cde980d4c 100755
--- a/tests/small_model_debugging/test_model.py
+++ b/tests/small_model_debugging/test_model.py
@@ -16,15 +16,18 @@ class SimpleModel(torch.nn.Module):
def __init__(self, hidden_dim, empty_grad=False):
super(SimpleModel, self).__init__()
- self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
+ self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
if empty_grad:
- self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
+ self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim,
+ hidden_dim)]) #QuantizeLinear(hidden_dim, hidden_dim)
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
def forward(self, x, y):
hidden = x
- hidden = self.linear(hidden)
- return self.cross_entropy_loss(hidden, y)
+ hidden1 = self.linear(hidden)
+ hidden2 = self.linear(hidden1)
+ return self.cross_entropy_loss(hidden2, y)
def create_config_from_dict(tmpdir, config_dict):
@@ -48,9 +51,11 @@ def get_args(tmpdir, config_dict):
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument('--zero', type=int, default=0)
+ parser.add_argument('--zero_hpz_partition_size', type=int, default=1)
args = parser.parse_args() #args=''
config_dict["zero_optimization"]["stage"] = args.zero
+ config_dict["zero_optimization"]["zero_hpz_partition_size"] = args.zero_hpz_partition_size
print('config_dict["zero_optimization"]', config_dict["zero_optimization"])
config_path = create_config_from_dict(tmpdir, config_dict)
@@ -68,7 +73,7 @@ def print0(msg):
torch.random.manual_seed(2222 + rank)
config_dict = {
- "train_batch_size": 8,
+ "train_batch_size": 256,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
@@ -78,17 +83,20 @@ def print0(msg):
},
"fp16": {
"enabled": True,
- "initial_scale_power": 15
+ "initial_scale_power": 8
},
"zero_optimization": {
"stage": 0,
"reduce_bucket_size": 20,
- "stage3_model_persistence_threshold": 10
+ "zero_hpz_partition_size": 1,
+ "reduce_scatter": True,
+ "zero_quantized_weights": False,
+ "zero_quantized_gradients": False
}
}
# "initial_scale_power": 15
args = get_args('/tmp/', config_dict)
-hidden_dim = 32
+hidden_dim = 4 * 1024
model = SimpleModel(hidden_dim, empty_grad=False)
@@ -104,8 +112,9 @@ def print_params(tag, model):
print0("{} {}:{}".format(tag, n, p))
-data_loader = get_data_loader(model=model, total_samples=1000, hidden_dim=hidden_dim, device=model.device)
+data_loader = get_data_loader(model=model, total_samples=256, hidden_dim=hidden_dim, device=model.device)
#print_params('pre-train', model)
+
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
if dist.get_rank() == 0:
@@ -113,4 +122,4 @@ def print_params(tag, model):
model.backward(loss)
model.step()
#print_params('step={}'.format(n), model)
- if n == 5: break
+ #if n == 5: break
diff --git a/tests/unit/runtime/zero/test_hpzero.py b/tests/unit/runtime/zero/test_hpzero.py
new file mode 100644
index 0000000000000..1d61d3c50a104
--- /dev/null
+++ b/tests/unit/runtime/zero/test_hpzero.py
@@ -0,0 +1,130 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import pytest
+import deepspeed.comm as dist
+from torch.nn import Module
+
+from unit.common import DistributedTest
+from unit.simple_model import random_dataloader
+
+import deepspeed
+
+from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
+
+import torch.nn as nn
+
+
+class NNModel(nn.Module):
+
+ def __init__(self, h_dim=1024, n_layers=2):
+ super(NNModel, self).__init__()
+ self.layers = nn.ModuleList([nn.Linear(h_dim, h_dim) for i in range(n_layers)])
+ self.cross_entropy_loss = nn.CrossEntropyLoss()
+
+ def forward(self, x, y):
+ for layer in self.layers:
+ x = layer(x)
+ return self.cross_entropy_loss(x, y)
+
+
+def test_zero_hpz_partition_size_config():
+ config = DeepSpeedZeroConfig(**{"zero_hpz_partition_size": 4})
+ assert config.zero_hpz_partition_size == 4
+
+
+def _assert_no_secondary_tensor_group(model: Module) -> None:
+ for _, param in model.named_parameters():
+ assert param.ds_secondary_tensor is None
+ assert param.ds_zero_param_process_group is None
+
+
+#Large sweep along hidden dim, num_layers, and zpg of different sizes
+#Assert when zpg=1 that secondary group and tensors are invalid
+@pytest.mark.parametrize("h_dim", [1024, 2000])
+@pytest.mark.parametrize("n_layers", [8, 20])
+@pytest.mark.parametrize("zpg", [1, 2, 4])
+class TesthpZeroConfigSweep(DistributedTest):
+ world_size = 4
+
+ def test(self, h_dim: int, n_layers: int, zpg: int) -> None:
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "zero_optimization": {
+ "stage": 3,
+ "stage3_max_reuse_distance": 0,
+ "zero_hpz_partition_size": zpg,
+ "contiguous_gradients": True,
+ "overlap_comm": True,
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1.
+ }
+ },
+ "fp16": {
+ "enabled": True,
+ "loss_scale": 1.,
+ }
+ }
+
+ model = NNModel(h_dim, n_layers)
+ model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+ data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=h_dim, device=model.device)
+ dist.barrier()
+ if zpg == 1:
+ _assert_no_secondary_tensor_group(model)
+
+ for n, batch in enumerate(data_loader):
+ loss = model(batch[0], batch[1])
+ model.backward(loss)
+ model.step()
+
+
+def _assert_secondary_tensor_size(model: Module) -> None:
+ for _, param in model.named_parameters():
+ assert param.ds_secondary_tensor is not None
+ assert param.ds_secondary_tensor.size()[0] % param.ds_tensor.size()[0] == 0
+
+
+#Tests that secondary tensors are available and are of right sizes
+@pytest.mark.parametrize("h_dim", [1024, 4000])
+@pytest.mark.parametrize("n_layers", [8, 20])
+@pytest.mark.parametrize("zpg", [2, 4])
+class TestSecondaryTensorSize(DistributedTest):
+ world_size = 4
+
+ def test(self, h_dim: int, n_layers: int, zpg: int) -> None:
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "zero_optimization": {
+ "stage": 3,
+ "stage3_max_reuse_distance": 0,
+ "zero_hpz_partition_size": zpg,
+ "contiguous_gradients": True,
+ "overlap_comm": True,
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1.
+ }
+ },
+ "fp16": {
+ "enabled": True,
+ "loss_scale": 1.,
+ }
+ }
+
+ model = NNModel(h_dim, n_layers)
+ model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+ data_loader = random_dataloader(model=model, total_samples=4, hidden_dim=h_dim, device=model.device)
+ dist.barrier()
+
+ for n, batch in enumerate(data_loader):
+ loss = model(batch[0], batch[1])
+ model.backward(loss)
+ _assert_secondary_tensor_size(model)
+ if n == 0: break
diff --git a/tests/unit/runtime/zero/test_qgzero.py b/tests/unit/runtime/zero/test_qgzero.py
new file mode 100644
index 0000000000000..ccd0f166d305f
--- /dev/null
+++ b/tests/unit/runtime/zero/test_qgzero.py
@@ -0,0 +1,61 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import pytest
+import deepspeed.comm as dist
+from unit.common import DistributedTest
+from unit.simple_model import random_dataloader
+
+import deepspeed
+import torch.nn as nn
+
+
+class NNModel(nn.Module):
+
+ def __init__(self, h_dim=1024, n_layers=2):
+ super(NNModel, self).__init__()
+ self.layers = nn.ModuleList([nn.Linear(h_dim, h_dim) for i in range(n_layers)])
+ self.cross_entropy_loss = nn.CrossEntropyLoss()
+
+ def forward(self, x, y):
+ for layer in self.layers:
+ x = layer(x)
+ return self.cross_entropy_loss(x, y)
+
+
+#Large sweep along hidden dim, num_layers of different sizes for qgZeRO.
+@pytest.mark.parametrize("h_dim", [1024, 2000])
+@pytest.mark.parametrize("n_layers", [8, 20])
+class TesthpZeroConfigSweep(DistributedTest):
+ world_size = 4
+
+ def test(self, h_dim: int, n_layers: int) -> None:
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "zero_optimization": {
+ "stage": 3,
+ "reduce_scatter": True,
+ "zero_quantized_gradients": True
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1.
+ }
+ },
+ "fp16": {
+ "enabled": True,
+ "loss_scale": 1.,
+ }
+ }
+
+ model = NNModel(h_dim, n_layers)
+ model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+ data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=h_dim, device=model.device)
+ dist.barrier()
+
+ for n, batch in enumerate(data_loader):
+ loss = model(batch[0], batch[1])
+ model.backward(loss)
+ model.step()
diff --git a/tests/unit/runtime/zero/test_qwzero.py b/tests/unit/runtime/zero/test_qwzero.py
new file mode 100644
index 0000000000000..71a0914e1a567
--- /dev/null
+++ b/tests/unit/runtime/zero/test_qwzero.py
@@ -0,0 +1,61 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import pytest
+import deepspeed.comm as dist
+from unit.common import DistributedTest
+from unit.simple_model import random_dataloader
+
+import deepspeed
+import torch.nn as nn
+
+
+class NNModel(nn.Module):
+
+ def __init__(self, h_dim=1024, n_layers=2):
+ super(NNModel, self).__init__()
+ self.layers = nn.ModuleList([nn.Linear(h_dim, h_dim) for i in range(n_layers)])
+ self.cross_entropy_loss = nn.CrossEntropyLoss()
+
+ def forward(self, x, y):
+ for layer in self.layers:
+ x = layer(x)
+ return self.cross_entropy_loss(x, y)
+
+
+#Large sweep along hidden dim, num_layers of different sizes for qwZeRO.
+@pytest.mark.parametrize("h_dim", [1024, 2048])
+@pytest.mark.parametrize("n_layers", [8, 20])
+class TesthpZeroConfigSweep(DistributedTest):
+ world_size = 4
+
+ def test(self, h_dim: int, n_layers: int) -> None:
+ config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "zero_optimization": {
+ "stage": 3,
+ "reduce_scatter": True,
+ "zero_quantized_weights": True
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1.
+ }
+ },
+ "fp16": {
+ "enabled": True,
+ "loss_scale": 1.,
+ }
+ }
+
+ model = NNModel(h_dim, n_layers)
+ model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+ data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=h_dim, device=model.device)
+ dist.barrier()
+
+ for n, batch in enumerate(data_loader):
+ loss = model(batch[0], batch[1])
+ model.backward(loss)
+ model.step()