Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add triton fastertransformer backend support for deberta #725

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,13 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:BertLayerWeight>
$<TARGET_OBJECTS:BertTritonBackend>
$<TARGET_OBJECTS:BertWeight>
$<TARGET_OBJECTS:Deberta>
$<TARGET_OBJECTS:DebertaLayerWeight>
$<TARGET_OBJECTS:DebertaTritonBackend>
$<TARGET_OBJECTS:DebertaWeight>
$<TARGET_OBJECTS:DecoderCrossAttentionLayer>
$<TARGET_OBJECTS:DecoderSelfAttentionLayer>
$<TARGET_OBJECTS:DisentangledAttentionLayer>
$<TARGET_OBJECTS:DynamicDecodeLayer>
$<TARGET_OBJECTS:FfnLayer>
$<TARGET_OBJECTS:FusedAttentionLayer>
Expand Down Expand Up @@ -356,6 +361,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:T5EncoderTritonBackend>
$<TARGET_OBJECTS:TensorParallelDecoderCrossAttentionLayer>
$<TARGET_OBJECTS:TensorParallelDecoderSelfAttentionLayer>
$<TARGET_OBJECTS:TensorParallelDisentangledAttentionLayer>
$<TARGET_OBJECTS:TensorParallelGeluFfnLayer>
$<TARGET_OBJECTS:TensorParallelSiluFfnLayer>
$<TARGET_OBJECTS:TensorParallelGptContextAttentionLayer>
Expand Down Expand Up @@ -384,6 +390,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:cutlass_preprocessors>
$<TARGET_OBJECTS:decoder_masked_multihead_attention>
$<TARGET_OBJECTS:decoding_kernels>
$<TARGET_OBJECTS:disentangled_attention_kernels>
$<TARGET_OBJECTS:fpA_intB_gemm>
$<TARGET_OBJECTS:gen_relative_pos_bias>
$<TARGET_OBJECTS:gpt_kernels>
Expand Down
5 changes: 5 additions & 0 deletions src/fastertransformer/models/deberta/Deberta.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ class Deberta: public BaseLayer {
const std::vector<Tensor>* input_tensors,
const DebertaWeight<T>* deberta_weights);
void forward(TensorMap* output_tensors, TensorMap* input_tensors, const DebertaWeight<T>* deberta_weights);

inline size_t getHiddenUnits()
{
return hidden_units_;
}
};

} // namespace fastertransformer
2 changes: 1 addition & 1 deletion src/fastertransformer/models/deberta/DebertaWeight.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ void DebertaWeight<T>::loadModel(std::string dir_path)

for (uint l = 0; l < num_layer_; l++) {
if (isValidLayerParallelId(l)) {
deberta_layer_weights[l].loadModel(dir_path + "model.encoder.layer." + std::to_string(l) + ".",
deberta_layer_weights[l].loadModel(dir_path + "/model.encoder.layer." + std::to_string(l) + ".",
model_file_type);
}
}
Expand Down
1 change: 1 addition & 0 deletions src/fastertransformer/triton_backend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ if (ENABLE_FP8)
add_subdirectory(multi_gpu_gpt_fp8)
endif()
add_subdirectory(bert)
add_subdirectory(deberta)
25 changes: 25 additions & 0 deletions src/fastertransformer/triton_backend/deberta/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

cmake_minimum_required(VERSION 3.8)

set(deberta_triton_backend_files
DebertaTritonModel.cc
DebertaTritonModelInstance.cc
)

add_library(DebertaTritonBackend STATIC ${deberta_triton_backend_files})
set_property(TARGET DebertaTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(DebertaTritonBackend PRIVATE Deberta TransformerTritonBackend -lcublasLt)
target_compile_features(DebertaTritonBackend PRIVATE cxx_std_14)
225 changes: 225 additions & 0 deletions src/fastertransformer/triton_backend/deberta/DebertaTritonModel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "3rdparty/INIReader.h"

#include "src/fastertransformer/triton_backend/deberta/DebertaTritonModel.h"
#include "src/fastertransformer/triton_backend/deberta/DebertaTritonModelInstance.h"

namespace ft = fastertransformer;

template<typename T>
DebertaTritonModel<T>::DebertaTritonModel(size_t tensor_para_size,
size_t pipeline_para_size,
bool enable_custom_all_reduce,
std::string model_dir,
bool is_sparse,
bool is_remove_padding):
tensor_para_size_(tensor_para_size),
pipeline_para_size_(pipeline_para_size),
shared_weights_(std::vector<std::shared_ptr<ft::DebertaWeight<T>>>(ft::getDeviceCount())),
enable_custom_all_reduce_(enable_custom_all_reduce),
model_dir_(model_dir),
is_sparse_(is_sparse),
is_remove_padding_(is_remove_padding)
{
FT_CHECK_WITH_INFO(is_sparse == false, "still not support sparse in deberta backend");

INIReader reader = INIReader(model_dir + "/config.ini");
if (reader.ParseError() < 0) {
std::cout << "[ERROR] Can't load '" << model_dir << "/config.ini"
<< "'\n";
ft::FT_CHECK(false);
}

/* Deberta base Configuration File Example
[deberta]
model_name = deberta
hidden_size = 1024
num_layer = 24
head_num = 16
size_per_head = 64
activation_type = gelu
inter_size = 4096
vocab_size = 128100
max_relative_positions = 512
relative_position_buckets = 256
weight_data_type = fp32
*/

model_name_ = reader.Get("deberta", "model_name");
head_num_ = reader.GetInteger("deberta", "head_num");
size_per_head_ = reader.GetInteger("deberta", "size_per_head");
inter_size_ = reader.GetInteger("deberta", "inter_size");
vocab_size_ = reader.GetInteger("deberta", "vocab_size");
num_layer_ = reader.GetInteger("deberta", "num_layer");
max_relative_positions_ = reader.GetInteger("deberta", "max_relative_positions");
relative_position_buckets_ = reader.GetInteger("deberta", "relative_position_buckets");
layernorm_type_ = ft::getLayerNormType("post_layernorm");
activation_type_ = ft::getActivationType(reader.Get("deberta", "activation_type", "Gelu"));
q_scaling_ = reader.GetFloat("deberta", "q_scaling", sqrtf(3.0f));
}

template<typename T>
std::unique_ptr<AbstractTransformerModelInstance>
DebertaTritonModel<T>::createModelInstance(int device_id,
int rank,
cudaStream_t stream,
std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params,
std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm)
{
ft::check_cuda_error(cudaSetDevice(device_id));
const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_);
const int tensor_para_rank = rank % tensor_para_size_;
const int pipeline_para_rank = rank / tensor_para_size_;

std::unique_ptr<ft::Allocator<ft::AllocatorType::CUDA>> allocator(
new ft::Allocator<ft::AllocatorType::CUDA>(device_id));

allocator->setStream(stream);

cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle;

cublasCreate(&cublas_handle);
cublasLtCreate(&cublaslt_handle);
cublasSetStream(cublas_handle, stream);

std::unique_ptr<ft::cublasAlgoMap> cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in"));
std::unique_ptr<std::mutex> cublas_wrapper_mutex(new std::mutex());
std::unique_ptr<ft::cublasMMWrapper> cublas_wrapper(new ft::cublasMMWrapper(
cublas_handle, cublaslt_handle, stream, cublas_algo_map.get(), cublas_wrapper_mutex.get(), allocator.get()));

std::unique_ptr<cudaDeviceProp> cuda_device_prop_ptr(new cudaDeviceProp);
ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id));

if (std::is_same<T, half>::value) {
cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F);
}
#ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) {
cublas_wrapper->setBF16GemmConfig();
}
#endif
else if (std::is_same<T, float>::value) {
cublas_wrapper->setFP32GemmConfig();
}

ft::NcclParam tensor_para = nccl_params.first[comms_rank];
ft::NcclParam pipeline_para = nccl_params.second[comms_rank];

auto deberta =
std::make_unique<ft::Deberta<T>>(ft::Deberta<T>(0, // max_batch_size, FT will adjust the buffer automatically.
0, // max_seq_len, FT will adjust the buffer automatically.
head_num_,
size_per_head_,
max_relative_positions_,
relative_position_buckets_,
inter_size_,
num_layer_,
q_scaling_,
stream,
cublas_wrapper.get(),
allocator.get(),
false,
is_sparse_,
activation_type_,
layernorm_type_,
tensor_para,
pipeline_para,
custom_all_reduce_comm,
enable_custom_all_reduce_));

#ifdef SPARSITY_ENABLED
if (is_sparse_) {
for (int i = 0; i < num_layer_; ++i) {
shared_weights_[device_id]->deberta_layer_weights[i].compress_weights(*(cublas_wrapper.get()),
head_num_ * size_per_head_);
}
}
#endif

return std::unique_ptr<DebertaTritonModelInstance<T>>(new DebertaTritonModelInstance<T>(std::move(deberta),
shared_weights_[device_id],
std::move(allocator),
std::move(cublas_algo_map),
std::move(cublas_wrapper_mutex),
std::move(cublas_wrapper),
std::move(cuda_device_prop_ptr)));
}

template<typename T>
void DebertaTritonModel<T>::createSharedWeights(int device_id, int rank)
{
ft::check_cuda_error(cudaSetDevice(device_id));
const int tensor_para_rank = rank % tensor_para_size_;
const int pipeline_para_rank = rank / tensor_para_size_;
shared_weights_[device_id] = std::make_shared<ft::DebertaWeight<T>>(head_num_ * size_per_head_,
inter_size_,
max_relative_positions_,
relative_position_buckets_,
vocab_size_,
num_layer_,
tensor_para_size_,
tensor_para_rank,
pipeline_para_size_,
pipeline_para_rank);

shared_weights_[device_id]->loadModel(model_dir_);
return;
}

template<typename T>
std::string DebertaTritonModel<T>::toString()
{
std::stringstream ss;
ss << "Model: " << model_name_ << "\nmodel_dir: " << model_dir_ << "\nhead_num: " << head_num_
<< "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_
<< "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_
<< "\nmax_relative_positions: " << max_relative_positions_ << "\nrelative_position_buckets: " << relative_position_buckets_
<< "\nq_scaling: " << q_scaling_ << "\nis_remove_padding: " << is_remove_padding_
<< "\nis_sparse: " << is_sparse_ << "\nactivation_type: " << static_cast<int>(activation_type_)
<< "\nlayernorm_type: " << static_cast<int>(layernorm_type_) << "\nvocab_size: " << vocab_size_
<< "\nenable_custom_all_reduce:" << enable_custom_all_reduce_ << std::endl;

return ss.str();
}

template<typename T>
void DebertaTritonModel<T>::createCustomComms(
std::vector<std::shared_ptr<ft::AbstractCustomComm>>* custom_all_reduce_comms, int world_size)
{
using commDataType = typename ft::CustomARCommTypeConverter<T>::Type;
ft::initCustomAllReduceComm<commDataType>(custom_all_reduce_comms, enable_custom_all_reduce_, world_size);
}

template<typename T>
int DebertaTritonModel<T>::getTensorParaSize()
{
return tensor_para_size_;
}

template<typename T>
int DebertaTritonModel<T>::getPipelineParaSize()
{
return pipeline_para_size_;
}

template struct DebertaTritonModel<float>;
template struct DebertaTritonModel<half>;
#ifdef ENABLE_BF16
template struct DebertaTritonModel<__nv_bfloat16>;
#endif
70 changes: 70 additions & 0 deletions src/fastertransformer/triton_backend/deberta/DebertaTritonModel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "src/fastertransformer/models/deberta/Deberta.h"
#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp"

namespace ft = fastertransformer;

template<typename T>
struct DebertaTritonModel: public AbstractTransformerModel {
DebertaTritonModel(size_t tensor_para_size,
size_t pipeline_para_size,
bool enable_custom_all_reduce,
std::string model_dir,
bool is_sparse,
bool is_remove_padding);

virtual std::unique_ptr<AbstractTransformerModelInstance>
createModelInstance(int deviceId,
int rank,
cudaStream_t stream,
std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params,
std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm = nullptr) override;

virtual void createSharedWeights(int deviceId, int rank) override;

virtual void createCustomComms(std::vector<std::shared_ptr<ft::AbstractCustomComm>>* custom_all_reduce_comms,
int world_size) override;

virtual std::string toString() override;
virtual int getTensorParaSize() override;
virtual int getPipelineParaSize() override;

private:
size_t head_num_;
size_t size_per_head_;
size_t inter_size_;
size_t num_layer_;
size_t vocab_size_;
size_t tensor_para_size_;
size_t pipeline_para_size_;
size_t max_relative_positions_;
size_t relative_position_buckets_;

float q_scaling_;
bool is_remove_padding_;
bool is_sparse_;
ft::ActivationType activation_type_;
ft::LayerNormType layernorm_type_;

std::string model_name_;
std::string model_dir_;
bool enable_custom_all_reduce_ = 0;
std::vector<std::shared_ptr<ft::DebertaWeight<T>>> shared_weights_;
};
Loading