ZeRO++ clean release (#526)
Co-authored-by: HeyangQin <[email protected]>
Co-authored-by: GuanhuaWang <[email protected]>
Co-authored-by: cmikeh2 <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
9 people authored Jun 8, 2023
1 parent df42509 commit 1e7a41c
Showing 24 changed files with 1,748 additions and 125 deletions.
26 changes: 26 additions & 0 deletions csrc/includes/quantization.h
@@ -40,6 +40,32 @@ void launch_dequantize_kernel(T* dequant_data,
int total_elems,
cudaStream_t stream);

void launch_swizzled_quant(int8_t* q_data,
float* q_scales,
const __half* input_data,
int num_bits,
quantize::Type q_type,
int groups,
int elems_per_group,
int pipelining,
int nodes,
int devices_per_node,
cudaStream_t stream);

void launch_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int num_gpus,
int num_bits,
quantize::Type quant_type,
int out_groups,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
cudaStream_t stream);

template <typename T>
void launch_fake_quantize_kernel(T* vals,
int total_count,
91 changes: 91 additions & 0 deletions csrc/quantization/pt_binding.cpp
@@ -136,6 +136,95 @@ at::Tensor dequantize(at::Tensor& quantized_data,
return output;
}

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
quantize::Type quant_type,
int pipeline_size,
int nodes,
int devices_per_node)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
auto scales = torch::empty({groups, scales_elems}, scales_options);

auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

const int quantization_scalar = 8 / num_bits;
const int compressed_vals = at::numel(input_vals) / quantization_scalar;

auto output = torch::empty({compressed_vals}, output_options);
const int elems_per_group = at::numel(input_vals) / groups;

launch_swizzled_quant((int8_t*)output.data_ptr(),
(float*)scales.data_ptr(),
(__half*)input_vals.data_ptr(),
num_bits,
quant_type,
groups,
elems_per_group,
pipeline_size,
nodes,
devices_per_node,
at::cuda::getCurrentCUDAStream());

return {output, scales};
}
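
For reference, a minimal sketch of the size bookkeeping ds_swizzle_quant performs, using hypothetical values (1,048,576 half-precision inputs, 512 groups, 4-bit symmetric quantization); the snippet is illustrative only and not part of this commit:

#include <cassert>

// Illustrative only: mirrors the arithmetic in ds_swizzle_quant above.
// For num_bits = 4, two quantized values are packed per int8 byte, so the
// compressed tensor holds numel / 2 bytes; symmetric quantization stores one
// float scale per group (asymmetric would store a scale and an offset).
int main()
{
    const int numel = 1048576;  // hypothetical number of __half inputs
    const int groups = 512;     // hypothetical number of quantization groups
    const int num_bits = 4;

    const int quantization_scalar = 8 / num_bits;             // values packed per byte
    const int compressed_vals = numel / quantization_scalar;  // int8 output elements
    const int elems_per_group = numel / groups;               // halves per group

    assert(compressed_vals == 524288);
    assert(elems_per_group == 2048);
    return 0;
}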

std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
at::Tensor& input_scales,
int in_groups,
int out_groups,
int num_bits,
quantize::Type quant_type)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
auto scales = torch::empty({out_groups, scales_elems}, scales_options);

auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

std::vector<long int> sz(input_vals.sizes().begin(), input_vals.sizes().end());
const int gpu_per_node = 16; // hard-coded; ideally this would depend on the machine (e.g., in_groups / out_groups)
sz[sz.size() - 1] = sz.back() / gpu_per_node; // shrink the last dim by the number of GPUs per node
const int elems_per_in_tensor = at::numel(input_vals) / gpu_per_node;
auto output = torch::empty(sz, output_options);

const int elems_per_in_group = elems_per_in_tensor / (in_groups / gpu_per_node);
const int elems_per_out_group = elems_per_in_tensor / out_groups;

launch_dequant_reduce((int8_t*)output.data_ptr(),
(float*)scales.data_ptr(),
(const int8_t*)input_vals.data_ptr(),
(const float*)input_scales.data_ptr(),
gpu_per_node,
num_bits,
quant_type,
out_groups,
elems_per_out_group,
elems_per_in_tensor,
in_groups / gpu_per_node,
elems_per_in_group,
at::cuda::getCurrentCUDAStream());
return {output, scales};
}
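
A similar back-of-the-envelope check for quantized_reduction, under the 16-GPUs-per-node assumption hard-coded above; all sizes below are hypothetical and the snippet is illustrative only:

#include <cassert>

// Illustrative only: mirrors the partitioning in quantized_reduction above.
// The int8 input is treated as 16 per-GPU tensors along the last dimension;
// each contributes in_groups / 16 groups, and the reduced output is
// re-quantized into out_groups groups.
int main()
{
    const int numel = 16 * 1048576;  // hypothetical int8 elements across 16 ranks
    const int gpu_per_node = 16;     // hard-coded in quantized_reduction
    const int in_groups = 8192;      // hypothetical
    const int out_groups = 512;      // hypothetical

    const int elems_per_in_tensor = numel / gpu_per_node;                       // 1048576
    const int groups_per_in_tensor = in_groups / gpu_per_node;                  // 512
    const int elems_per_in_group = elems_per_in_tensor / groups_per_in_tensor;  // 2048
    const int elems_per_out_group = elems_per_in_tensor / out_groups;           // 2048

    assert(elems_per_in_group == 2048 && elems_per_out_group == 2048);
    return 0;
}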

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
@@ -158,4 +247,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("quantize", &quantize_kernel);
m.def("dequantize", &dequantize<__half>);
m.def("dequantize_fp32", &dequantize<float>);
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
}
263 changes: 263 additions & 0 deletions csrc/quantization/quant_reduce.cu
@@ -0,0 +1,263 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <cstdio>
#include "dequantization_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization_utils.h"
#include "reduction_utils.h"

using rop = reduce::ROpType;

/*
TODO(cmikeh2): Add an implementation that better handles larger nodes. It would likely make sense
to leverage some parallel reductions here to improve performance.
*/

template <int numBits, int numTensors, int totalChunks, quantize::Type quantType>
__global__ void __launch_bounds__(1024) dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
int num_tensors)
{
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

// NOTE(cmikeh2): This probably could be hardcoded to a larger number,
// but that means even stronger restrictions on the number of elements per group.
// A performance analysis here might be beneficial.
constexpr int mem_granularity = (numBits == 8) ? 8 : 4;
constexpr int elems_per_load = mem_granularity / sizeof(int8_t); // div by 1
constexpr int storage_values = 16 / sizeof(__half2);

const int block_offset = tb.group_index().x * elems_per_out_group;
const int elem_offset = tb.thread_index().x * elems_per_load;
const int base_offset = block_offset + elem_offset;
const int stride = tb.group_dim().x * elems_per_load;

__half2 local_buffer[totalChunks * storage_values];

quantize::GroupStats<quantType> stats;

#pragma unroll
for (int i = 0; i < totalChunks; i++) {
__half2* iteration_buffer = local_buffer + i * storage_values;

#pragma unroll
for (int j = 0; j < storage_values; j++) {
iteration_buffer[j] = reduce::init<rop::Add, __half2>();
}

const int iter_offset = i * stride + base_offset;
const int iter_scale_idx = iter_offset / elems_per_in_group;
bool do_loads = i * stride + elem_offset < elems_per_out_group;

if (numTensors > 0) {
#pragma unroll
for (int j = 0; j < numTensors; j++) {
if (do_loads) {
int8_t load_buffer[elems_per_load];

mem_access::load_global<mem_granularity>(
load_buffer, input_data + j * elems_per_in_tensor + iter_offset);

quantize::Params<quantType, numBits> params(
input_scales + j * groups_per_in_tensor, iter_scale_idx);

__half2 dequant_buffer[storage_values];
dequantize::chunk<numBits, quantType>(dequant_buffer, load_buffer, params);

#pragma unroll
for (int k = 0; k < storage_values; k++) {
iteration_buffer[k] =
reduce::element<rop::Add>(iteration_buffer[k], dequant_buffer[k]);
}
}
}
} else {
#pragma unroll 4
for (int j = 0; j < num_tensors; j++) {
if (do_loads) {
int8_t load_buffer[elems_per_load];

mem_access::load_global<mem_granularity>(
load_buffer, input_data + j * elems_per_in_tensor + iter_offset);

quantize::Params<quantType, numBits> params(
input_scales + j * groups_per_in_tensor, iter_scale_idx);

__half2 dequant_buffer[storage_values];
dequantize::chunk<numBits, quantType>(dequant_buffer, load_buffer, params);

#pragma unroll
for (int k = 0; k < storage_values; k++) {
iteration_buffer[k] =
reduce::element<rop::Add>(iteration_buffer[k], dequant_buffer[k]);
}
}
}
}

#pragma unroll
for (int j = 0; j < storage_values; j++) { stats.update(iteration_buffer[j]); }
}

auto params = stats.template get_params<numBits, 1024>(tb, warp);

if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); }

#pragma unroll
for (int i = 0; i < totalChunks; i++) {
const int iter_offset = i * stride + base_offset;
if (i * stride + elem_offset < elems_per_out_group) {
int8_t local_output[elems_per_load];
quantize::_chunk<numBits, quantType>(
local_output, local_buffer + i * storage_values, params);
mem_access::store_global<mem_granularity>(reduced_data + iter_offset, local_output);
}
}
}
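
To make the kernel's indexing concrete, a small illustrative check of the offset arithmetic at the top of dequant_reduce, under hypothetical launch parameters (one block per output group, 1024 threads, numBits = 8 so each thread loads 8 int8 values per chunk); not part of this commit:

#include <cassert>

// Illustrative only: reproduces the block/thread offset computation above.
int main()
{
    const int elems_per_load = 8;           // mem_granularity / sizeof(int8_t) for 8-bit
    const int threads_per_block = 1024;
    const int elems_per_out_group = 16384;  // hypothetical

    const int block_idx = 3;    // hypothetical output group handled by this block
    const int thread_idx = 17;  // hypothetical thread within the block

    const int block_offset = block_idx * elems_per_out_group;  // 49152
    const int elem_offset = thread_idx * elems_per_load;       // 136
    const int base_offset = block_offset + elem_offset;        // 49288
    const int stride = threads_per_block * elems_per_load;     // 8192: one whole block per chunk

    // Chunk i of this thread covers [base_offset + i * stride, base_offset + i * stride + 8).
    assert(base_offset == 49288 && stride == 8192);
    return 0;
}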

template <int Power>
int32_t pow2_round(int32_t raw_value)
{
return (((raw_value - 1) >> Power) + 1) << Power;
}
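
pow2_round rounds its argument up to the nearest multiple of 2^Power; a quick illustrative check (not part of this commit):

#include <cassert>
#include <cstdint>

// Illustrative only: same rounding helper as above, exercised on a few values.
template <int Power>
int32_t pow2_round_demo(int32_t raw_value)
{
    return (((raw_value - 1) >> Power) + 1) << Power;
}

int main()
{
    assert(pow2_round_demo<1>(4) == 4);    // already a multiple of 2
    assert(pow2_round_demo<1>(5) == 6);    // rounds up to the next even value
    assert(pow2_round_demo<1>(11) == 12);  // feeds the unroll = 4/6/8/10/12 dispatch below
    assert(pow2_round_demo<2>(5) == 8);    // Power = 2 rounds up to a multiple of 4
    return 0;
}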

#define LAUNCH_DEQUANT_REDUCE(num_chunks) \
dequant_reduce<numBits, numTensors, num_chunks, quantType> \
<<<grid, block, 0, stream>>>(reduced_data, \
reduced_scales, \
input_data, \
input_scales, \
elems_per_out_group, \
elems_per_in_tensor, \
groups_per_in_tensor, \
elems_per_in_group, \
num_tensors);

template <int numBits, int numTensors, quantize::Type quantType>
void launch_dequant_reduce_impl(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int out_groups,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
int num_tensors,
cudaStream_t stream)
{
// This equality is a coincidence: it is derived from 8 halves per 16 bytes with 2-way packing for int4.
constexpr int elems_per_thread = numBits;
const int one_step_threads =
next_pow2((elems_per_out_group + elems_per_thread - 1) / (elems_per_thread));
// TODO(cmikeh2): Tune this
const int threads = (one_step_threads < 1024) ? one_step_threads : 1024;

dim3 block(threads);
dim3 grid(out_groups);

const int elems_per_step = threads * elems_per_thread;
const int unroll_raw = (elems_per_out_group + elems_per_step - 1) / elems_per_step;

const int unroll = (unroll_raw >= 4) ? pow2_round<1>(unroll_raw) : unroll_raw;

if (unroll == 1) {
// 0-4096 elems
LAUNCH_DEQUANT_REDUCE(1);
} else if (unroll == 2) {
// 4097-8192 etc...
LAUNCH_DEQUANT_REDUCE(2);
} else if (unroll == 3) {
LAUNCH_DEQUANT_REDUCE(3);
} else if (unroll == 4) {
LAUNCH_DEQUANT_REDUCE(4);
} else if (unroll == 6) {
LAUNCH_DEQUANT_REDUCE(6);
} else if (unroll == 8) {
LAUNCH_DEQUANT_REDUCE(8);
} else if (unroll == 10) {
LAUNCH_DEQUANT_REDUCE(10);
} else if (unroll == 12) {
// 48k limit
LAUNCH_DEQUANT_REDUCE(12);
} else {
assert(false);
}
}
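
A worked example of the launch configuration above, using a hypothetical 8-bit case with elems_per_out_group = 16384; next_pow2 is re-implemented locally (assumed to match the repo helper's round-up-to-power-of-two behavior) so the sketch stays self-contained:

#include <cassert>

// Illustrative only: reproduces the thread/unroll selection in
// launch_dequant_reduce_impl for numBits = 8, elems_per_out_group = 16384.
static int next_pow2_local(int v)  // local stand-in for the repo's next_pow2
{
    int p = 1;
    while (p < v) p <<= 1;
    return p;
}

int main()
{
    const int num_bits = 8;
    const int elems_per_out_group = 16384;  // hypothetical
    const int elems_per_thread = num_bits;  // 8 int8 elements per thread per step

    const int one_step_threads =
        next_pow2_local((elems_per_out_group + elems_per_thread - 1) / elems_per_thread);
    const int threads = (one_step_threads < 1024) ? one_step_threads : 1024;             // 1024
    const int elems_per_step = threads * elems_per_thread;                               // 8192
    const int unroll_raw = (elems_per_out_group + elems_per_step - 1) / elems_per_step;  // 2
    const int unroll = unroll_raw;  // < 4, so no pow2 rounding is applied

    assert(threads == 1024 && unroll == 2);  // lands in the LAUNCH_DEQUANT_REDUCE(2) branch
    return 0;
}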

#define LAUNCH_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE) \
launch_dequant_reduce_impl<NUM_BITS, NUM_GPUS, QUANT_TYPE>(reduced_data, \
reduced_scales, \
input_data, \
input_scales, \
out_groups, \
elems_per_out_group, \
elems_per_in_tensor, \
groups_per_in_tensor, \
elems_per_in_group, \
num_gpus, \
stream);

void launch_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int num_gpus,
int num_bits,
quantize::Type quant_type,
int out_groups,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
cudaStream_t stream)
{
if (quant_type == quantize::Type::Symmetric) {
if (num_bits == 4) {
if (num_gpus == 8) {
LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Symmetric);
} else if (num_gpus == 16) {
LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Symmetric);
} else {
LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Symmetric);
}
} else if (num_bits == 8) {
if (num_gpus == 8) {
LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Symmetric);
} else if (num_gpus == 16) {
LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Symmetric);
} else {
LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Symmetric);
}
}
} else if (quant_type == quantize::Type::Asymmetric) {
if (num_bits == 4) {
if (num_gpus == 8) {
LAUNCH_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Asymmetric);
} else if (num_gpus == 16) {
LAUNCH_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Asymmetric);
} else {
LAUNCH_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Asymmetric);
}
} else if (num_bits == 8) {
if (num_gpus == 8) {
LAUNCH_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Asymmetric);
} else if (num_gpus == 16) {
LAUNCH_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Asymmetric);
} else {
LAUNCH_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Asymmetric);
}
}
}
}