diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index e60fe3d67b5836..88fe29b0ff0a12 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -102,7 +102,8 @@ pass_library(add_support_int8_pass inference) pass_library(matmul_scale_fuse_pass inference) pass_library(gpu_cpu_map_matmul_to_mul_pass inference) pass_library(mixed_precision_configure_pass inference) -pass_library(desne_to_sparse_pass inference) +pass_library(desne_fc_to_sparse_pass inference) +pass_library(dense_multihead_matmul_to_sparse_pass inference) pass_library(generate_pass DEPS pass_desc_proto) target_link_libraries(generate_pass pass_desc_proto) diff --git a/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.cc b/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.cc new file mode 100644 index 00000000000000..c366d74d532612 --- /dev/null +++ b/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +DenseMultiheadMatmulToSparsePass::DenseMultiheadMatmulToSparsePass() { + AddOpCompat(OpCompat("multihead_matmul")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddInput("BiasQK") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + +void DenseMultiheadMatmulToSparsePass::ApplyImpl(Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + + std::string name_scope = "dense_multihead_matmul_to_sparse_pass"; + FusePassBase::Init(name_scope, graph); + GraphPatternDetector gpd; + + patterns::MultiheadMatmul multihead_matmul_pattern( + gpd.mutable_pattern(), "dense_multihead_matmul_replace_pass"); + multihead_matmul_pattern(); + int found_multihead_matmul_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "Replace dense multihead matmul with sparse multihead matmul."; + + /* if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + }*/ + + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_weights, + multihead_matmul_weights, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_bias, multihead_matmul_bias, + multihead_matmul_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_biasqk, multihead_matmul_biasqk, + multihead_matmul_pattern); + + auto *multihead_matmul_op = multihead_matmul->Op(); + auto w_name = multihead_matmul_op->Input("W")[0]; + // recognize sparse op by name + if (w_name.find("sparse_2_4") != w_name.npos) { + // fake op + OpDesc desc(multihead_matmul_op->Block()); + desc.SetType("sparse_multihead_matmul"); + desc.SetInput("Input", {multihead_matmul_input->Name()}); + desc.SetInput("W", {multihead_matmul_weights->Name()}); + desc.SetInput("Bias", {multihead_matmul_bias->Name()}); + desc.SetInput("BiasQK", {multihead_matmul_biasqk->Name()}); + desc.SetOutput("Out", {multihead_matmul_out->Name()}); + + // copy all attr + desc.SetAttr("alpha", multihead_matmul_op->GetAttr("alpha")); + desc.SetAttr("head_number", multihead_matmul_op->GetAttr("head_number")); + if (multihead_matmul_op->HasAttr("Input_scale")) { + desc.SetAttr("Input_scale", + multihead_matmul_op->GetAttr("Input_scale")); + } + if (multihead_matmul_op->HasAttr("fc_out_threshold")) { + desc.SetAttr("fc_out_threshold", + multihead_matmul_op->GetAttr("fc_out_threshold")); + } + if (multihead_matmul_op->HasAttr("qkv2context_plugin_int8")) { + desc.SetAttr("qkv2context_plugin_int8", + multihead_matmul_op->GetAttr("qkv2context_plugin_int8")); + } + if (multihead_matmul_op->HasAttr("dp_probs")) { + desc.SetAttr("dp_probs", multihead_matmul_op->GetAttr("dp_probs")); + } + if (multihead_matmul_op->HasAttr("out_threshold")) { + desc.SetAttr("out_threshold", + multihead_matmul_op->GetAttr("out_threshold")); + } + desc.Flush(); + GraphSafeRemoveNodes(g, {multihead_matmul}); + auto sparse_multihead_matmul_node = g->CreateOpNode(&desc); + + IR_NODE_LINK_TO(multihead_matmul_input, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(multihead_matmul_weights, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(multihead_matmul_bias, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(multihead_matmul_biasqk, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(sparse_multihead_matmul_node, multihead_matmul_out); + found_multihead_matmul_count++; + } + }; + + gpd(graph, handler); + AddStatis(found_multihead_matmul_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(dense_multihead_matmul_to_sparse_pass, + paddle::framework::ir::DenseMultiheadMatmulToSparsePass); diff --git a/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h b/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h new file mode 100644 index 00000000000000..a1e12857c659be --- /dev/null +++ b/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/inference/api/paddle_analysis_config.h" + +namespace paddle { +namespace framework { +namespace ir { + +/** + * Replace dense multihead_matmul op with sparse multihead_matmul op + */ +class Graph; + +class DenseMultiheadMatmulToSparsePass : public FusePassBase { + public: + DenseMultiheadMatmulToSparsePass(); + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + const std::string name_scope_{"dense_multihead_matmul_to_sparse_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/desne_to_sparse_pass.cc b/paddle/fluid/framework/ir/desne_fc_to_sparse_pass.cc similarity index 90% rename from paddle/fluid/framework/ir/desne_to_sparse_pass.cc rename to paddle/fluid/framework/ir/desne_fc_to_sparse_pass.cc index d2e0823c3323ae..c1027fdfb6f241 100644 --- a/paddle/fluid/framework/ir/desne_to_sparse_pass.cc +++ b/paddle/fluid/framework/ir/desne_fc_to_sparse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/desne_to_sparse_pass.h" +#include "paddle/fluid/framework/ir/desne_fc_to_sparse_pass.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -20,7 +20,7 @@ namespace paddle { namespace framework { namespace ir { -ReplaceDenseWithSparsePass::ReplaceDenseWithSparsePass() { +DenseFCToSparsePass::DenseFCToSparsePass() { AddOpCompat(OpCompat("fc")) .AddInput("Input") .IsTensor() @@ -36,16 +36,16 @@ ReplaceDenseWithSparsePass::ReplaceDenseWithSparsePass() { .End(); } -void ReplaceDenseWithSparsePass::ApplyImpl(Graph *graph) const { +void DenseFCToSparsePass::ApplyImpl(Graph *graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - std::string name_scope = "desne_to_sparse_pass"; + std::string name_scope = "desne_fc_to_sparse_pass"; FusePassBase::Init(name_scope, graph); GraphPatternDetector gpd; patterns::DenseFC dense_fc_pattern(gpd.mutable_pattern(), - "dense_replace_pass"); + "dense_fc_replace_pass"); dense_fc_pattern(); int found_dense_fc_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, @@ -115,5 +115,5 @@ void ReplaceDenseWithSparsePass::ApplyImpl(Graph *graph) const { } // namespace framework } // namespace paddle -REGISTER_PASS(desne_to_sparse_pass, - paddle::framework::ir::ReplaceDenseWithSparsePass); +REGISTER_PASS(desne_fc_to_sparse_pass, + paddle::framework::ir::DenseFCToSparsePass); diff --git a/paddle/fluid/framework/ir/desne_to_sparse_pass.h b/paddle/fluid/framework/ir/desne_fc_to_sparse_pass.h similarity index 88% rename from paddle/fluid/framework/ir/desne_to_sparse_pass.h rename to paddle/fluid/framework/ir/desne_fc_to_sparse_pass.h index 33e278d778fbbb..31ecf8f95dcc3f 100644 --- a/paddle/fluid/framework/ir/desne_to_sparse_pass.h +++ b/paddle/fluid/framework/ir/desne_fc_to_sparse_pass.h @@ -30,14 +30,14 @@ namespace ir { */ class Graph; -class ReplaceDenseWithSparsePass : public FusePassBase { +class DenseFCToSparsePass : public FusePassBase { public: - ReplaceDenseWithSparsePass(); + DenseFCToSparsePass(); protected: void ApplyImpl(ir::Graph* graph) const override; - const std::string name_scope_{"desne_to_sparse_pass"}; + const 
std::string name_scope_{"desne_fc_to_sparse_pass"}; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 410c899835a7ed..0aa7ae31847c70 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -3423,6 +3423,44 @@ PDNode *patterns::DenseFC::operator()() { return fc_out; } +PDNode *patterns::MultiheadMatmul::operator()() { + auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr()) + ->assert_is_op("multihead_matmul"); + // Input + auto *multihead_matmul_input = + pattern->NewNode(multihead_matmul_input_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "Input"); + // Filter + auto *multihead_matmul_weights = + pattern->NewNode(multihead_matmul_weights_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "W"); + // Bias + auto *multihead_matmul_bias = + pattern->NewNode(multihead_matmul_bias_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "Bias"); + // BiasQK + auto *multihead_matmul_biasqk = + pattern->NewNode(multihead_matmul_biasqk_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "BiasQK"); + // Output + auto *multihead_matmul_out = + pattern->NewNode(multihead_matmul_out_repr()) + ->AsOutput() + ->assert_is_op_output("multihead_matmul", "Out") + ->assert_is_only_output_of_op("multihead_matmul"); + + multihead_matmul + ->LinksFrom({multihead_matmul_input, multihead_matmul_weights, + multihead_matmul_bias, multihead_matmul_biasqk}) + .LinksTo({multihead_matmul_out}); + + return multihead_matmul_out; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 22a8c07c136942..aaaee721a9d325 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1971,6 +1971,24 @@ struct DenseFC : public PatternBase { PATTERN_DECL_NODE(fc_bias); }; +// +// \brief Pattern looking for multihead matmul fc. +// +struct MultiheadMatmul : public PatternBase { + MultiheadMatmul(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "multihead_matmul") {} + + PDNode* operator()(); + + // declare operator node's name + PATTERN_DECL_NODE(multihead_matmul); + PATTERN_DECL_NODE(multihead_matmul_out); + PATTERN_DECL_NODE(multihead_matmul_input); + PATTERN_DECL_NODE(multihead_matmul_weights); + PATTERN_DECL_NODE(multihead_matmul_bias); + PATTERN_DECL_NODE(multihead_matmul_biasqk); +}; + } // namespace patterns // Link two ir::Nodes from each other. 
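Note on how the two new passes decide what to rewrite: they do not inspect the weight values themselves. A dense `fc` or `multihead_matmul` op is replaced by its `sparse_*` counterpart only when the name of its `W` input contains the marker `sparse_2_4`, i.e. the weight is presumably pruned offline to 2:4 structured sparsity and renamed accordingly. A minimal sketch of that check (the helper name below is illustrative, not taken from the patch):

```cpp
#include <string>

// Illustrative helper (not from the patch) mirroring the check both passes
// perform: a weight is treated as 2:4-sparse only if its variable name
// carries the "sparse_2_4" marker.
inline bool IsSparse24Weight(const std::string& w_name) {
  return w_name.find("sparse_2_4") != std::string::npos;
}

// Usage sketch inside a pass handler:
//   if (IsSparse24Weight(op_desc->Input("W")[0])) { /* rewrite to sparse op */ }
```

Ops rewritten to `sparse_fc` / `sparse_multihead_matmul` this way are then admitted by the TensorRT op teller and lowered through the converters added later in this patch.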
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index eb02666a5b969e..5b14cf372f1d3f 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1810,6 +1810,7 @@ USE_TRT_CONVERTER(recover_padding)
 USE_TRT_CONVERTER(remove_padding)
 #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
 USE_TRT_CONVERTER(sparse_fc)
+USE_TRT_CONVERTER(sparse_multihead_matmul)
 #endif
 #endif
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 6589ad86f02b88..8345c546debab0 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -114,9 +114,10 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "remove_padding_recover_padding_pass",         //
       "delete_remove_padding_recover_padding_pass",  //
       // "yolo_box_fuse_pass",      //
-      "desne_to_sparse_pass",
-      "tensorrt_subgraph_pass",  //
-      "conv_bn_fuse_pass",       //
+      "desne_fc_to_sparse_pass",                //
+      "dense_multihead_matmul_to_sparse_pass",  //
+      "tensorrt_subgraph_pass",                 //
+      "conv_bn_fuse_pass",                      //
 #if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
                            // guaranteed at least v7
 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index c713e3a66ac718..3d01db9a1572df 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -6,6 +6,6 @@ else()
 endif()
 nv_library(tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto device_context boost)
 nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
-nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
+nv_test(test_tensorrt_engine SRCS test_engine.cc test_dynamic_engine.cc DEPS dynload_cuda tensorrt_engine tensorrt_plugin)
 add_subdirectory(plugin)
 add_subdirectory(convert)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 3f3184eb95d18f..f969f05ab0c3e1 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -11,7 +11,7 @@ preln_skip_layernorm.cc strided_slice_op.cc roll_op.cc transformer_input_convert
 recover_padding_op.cc)
 if (CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
-  list(APPEND CONVERT_FILES sparse_fc_op.cc)
+  list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc)
 endif()
 nv_library(tensorrt_converter
diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc
index 01c3fad7a3cbec..d32375c6584abc 100644
--- a/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc
@@ -254,15 +254,27 @@ class SparseFcOpConverter : public OpConverter {
     float* bias_data = nullptr;
     int bias_num = 0;
+    void* b_data = nullptr;
     if (with_bias) {
       auto* b_v = scope.GetVar(op_desc.Input("Bias").front());
       auto* b_t = b_v->GetMutable<framework::LoDTensor>();
       bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t);
       bias_num = b_t->numel();
+
+      half* half_bias_data = nullptr;
+      if (with_fp16) {
+        half_bias_data = new half[bias_num];
+        for (int i = 0; i < bias_num; i++) {
+          half_bias_data[i] = static_cast<half>(bias_data[i]);
+        }
+        b_data =
static_cast(half_bias_data); + } else { + b_data = static_cast(bias_data); + } } - TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, - static_cast(bias_data), - static_cast(bias_num)}; + TensorRTEngine::Weight bias{ + with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + b_data, static_cast(bias_num)}; // Running the TRT Static Shape mode: x_num_col_dims-1 if (!engine_->with_dynamic_shape()) { diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc new file mode 100644 index 00000000000000..ab01f040fe02de --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc @@ -0,0 +1,448 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class SparseMultiheadMatMulOpConverter : public OpConverter { + public: + plugin::SpmmPluginDynamic* new_spmm_plugin(TensorRTEngine::Weight* weight, + TensorRTEngine::Weight* bias, + nvinfer1::DataType type, + int outdim) { + plugin::SpmmPluginDynamic::Activation act = + plugin::SpmmPluginDynamic::Activation::kNone; + return new plugin::SpmmPluginDynamic("CustomSpmmPluginDynamic", type, + outdim, weight->get(), bias->get(), + act); + } + + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid sparse_multihead_matmul op to a corresponding " + "tensorrt " + "network structure"; + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("Input").front()); + + // fc weights and fc bias + auto weight_name = op_desc.Input("W").front(); + auto bias_name = op_desc.Input("Bias").front(); + + auto* weight_v = scope.FindVar(weight_name); + auto* weight_t = weight_v->GetMutable(); + + auto* bias_v = scope.FindVar(bias_name); + auto* bias_t = bias_v->GetMutable(); + + float* weight_data = nullptr; + bool qkv2context_plugin_int8 = op_desc.HasAttr("qkv2context_plugin_int8"); + float in_scale = 0.; + + if (op_desc.HasAttr("Input_scale")) { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); + engine_->SetTensorDynamicRange(input, in_scale); + } + weight_data = engine_->GetWeightCPUData(weight_name, weight_t); + + float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t); + std::vector weight_data_tmp; + weight_data_tmp.reserve(weight_t->numel()); + memcpy(weight_data_tmp.data(), weight_data, + weight_t->numel() * sizeof(float)); + + // (hidden_in, 3, hidden_out) + auto weight_dims = weight_t->dims(); + + int hidden_in = weight_dims[0]; // channels_in + int three = weight_dims[1]; // channels_out + int hidden_out = weight_dims[2]; // channels_out + int m = 
hidden_in; + int n = three * hidden_out; + auto tranpose_weight = [](const float* src, float* dst, int m, int n) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + dst[j * m + i] = src[i * n + j]; + } + } + }; + tranpose_weight(weight_data_tmp.data(), weight_data, m, n); + + int head_number = BOOST_GET_CONST(int, op_desc.GetAttr("head_number")); + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + + nvinfer1::ILayer* layer = nullptr; + auto output_name = op_desc.Output("Out")[0]; + + if (engine_->with_dynamic_shape()) { + if (engine_->use_oss()) { + if (engine_->precision() == AnalysisConfig::Precision::kFloat32) { + PADDLE_THROW(platform::errors::Fatal( + "use use_oss must be int8 or half, not float32.")); + } + nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_t->numel())}; + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_t->numel())}; + if (engine_->with_interleaved()) { + VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved"; + if (!op_desc.HasAttr("Input_scale")) { + PADDLE_THROW( + platform::errors::Fatal("use with_interleaved must be int8.")); + } + nvinfer1::ILayer* fc_layer = nullptr; + float dp_probs = 1.0 / 127.0; + nvinfer1::DimsHW nv_ksize(1, 1); + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, + nv_ksize, weight, bias); + fc_layer->setName( + ("Multihead: Convolution/FullyConnected: (Output: " + + output_name + ")") + .c_str()); + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("fc_out_threshold"), true, + platform::errors::InvalidArgument( + "must have out_threshold in multihead layers in int8 mode")); + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + if (qkv2context_plugin_int8) { + dp_probs = + BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0; + } + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomQKVToContextPluginDynamic", "3"); + assert(creator != nullptr); + std::vector fields{ + {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, + 1}, + {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, + 1}}; + if (qkv2context_plugin_int8) { + fields.push_back({"dq_probs", &dp_probs, + nvinfer1::PluginFieldType::kFLOAT32, 1}); + } + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast(malloc( + sizeof(*plugin_collection) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + + auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic", + plugin_collection); + free(plugin_collection); + + std::vector plugin_inputs; + plugin_inputs.emplace_back(fc_layer->getOutput(0)); + if (engine_->Has("ernie_pos_name")) { + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->Get("ernie_pos_name"))); + } else { + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->network() + ->getInput(2) + ->getName())); // cu_seqlens, eval_placeholder_2 + } + auto max_seqlen_tensor = + engine_->GetITensor(engine_->network()->getInput(3)->getName()); + engine_->SetTensorDynamicRange(max_seqlen_tensor, 1.0f); + auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( + engine_, Shuffle, + *const_cast(max_seqlen_tensor)); + nvinfer1::Dims shape_dim; + shape_dim.nbDims = 1; + shape_dim.d[0] = -1; + shuffle_layer->setReshapeDimensions(shape_dim); + 
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f); + plugin_inputs.emplace_back( + shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3 + shuffle_layer->setName( + ("Multihead: Shuffle: (Output: " + output_name + ")").c_str()); + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + layer = plugin_layer; + } else { + int head_size = hidden_out / head_number; + // [3, head_number, head_size, hidden_in] -> [head_number, 3, + // head_size, + // hidden_in] + auto transpose_weight_v2 = [](const float* src, float* dst, int three, + int head_number, int head_size, + int hidden_in) { + const int HH = head_size * hidden_in; + for (int i = 0; i < three; ++i) { + for (int n = 0; n < head_number; ++n) { + for (int hh = 0; hh < HH; ++hh) { + dst[n * three * HH + i * HH + hh] = + src[i * head_number * HH + n * HH + hh]; + } + } + } + }; + // [3, head_number, head_size] -> [head_number, 3, head_size] + auto transpose_bias_v2 = [](const float* src, float* dst, int N, + int H) { + for (int i = 0; i < 3; ++i) { + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h]; + } + } + } + }; + memcpy(weight_data_tmp.data(), weight_data, + weight_t->numel() * sizeof(float)); + transpose_weight_v2(weight_data_tmp.data(), weight_data, three, + head_number, head_size, hidden_in); + + std::vector bias_data_tmp; + bias_data_tmp.reserve(bias_t->numel()); + memcpy(bias_data_tmp.data(), bias_data, + bias_t->numel() * sizeof(float)); + transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number, + head_size); + + nvinfer1::ILayer* fc_layer = nullptr; + float dp_probs = 1.0 / 127.0; + if (op_desc.HasAttr("Input_scale")) { + nvinfer1::DimsHW nv_ksize(1, 1); + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, + nv_ksize, weight, bias); + } else { + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n, + weight, bias); + } + + if (op_desc.HasAttr("fc_out_threshold")) { + PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in multihead layers " + "in int8 mode")); + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + if (qkv2context_plugin_int8) { + dp_probs = + BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0; + } + } + + auto mask_tensor = engine_->GetITensor("qkv_plugin_mask"); + + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomQKVToContextPluginDynamic", "2"); + assert(creator != nullptr); + int type = static_cast(nvinfer1::DataType::kHALF); + if (qkv2context_plugin_int8 && + (engine_->precision() == AnalysisConfig::Precision::kInt8)) { + type = static_cast(nvinfer1::DataType::kINT8); + } + bool has_mask = true; + int var_seqlen = 1; + std::vector fields{ + {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, + {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, + 1}, + {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, + {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, + {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, + 1}}; + if (qkv2context_plugin_int8) { + fields.push_back({"dq_probs", &dp_probs, + nvinfer1::PluginFieldType::kFLOAT32, 1}); + } + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast(malloc( + sizeof(*plugin_collection) + + fields.size() * + 
sizeof(nvinfer1::PluginField))); // remember to free + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + + auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic", + plugin_collection); + free(plugin_collection); + + std::vector plugin_inputs; + plugin_inputs.emplace_back(fc_layer->getOutput(0)); + plugin_inputs.emplace_back(mask_tensor); + if (engine_->Has("ernie_pos_name")) { + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->Get("ernie_pos_name"))); + } else { + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->network() + ->getInput(2) + ->getName())); // cu_seqlens, eval_placeholder_2 + } + auto max_seqlen_tensor = + engine_->GetITensor(engine_->network()->getInput(3)->getName()); + auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( + engine_, Shuffle, + *const_cast(max_seqlen_tensor)); + nvinfer1::Dims shape_dim; + shape_dim.nbDims = 1; + shape_dim.d[0] = -1; + shuffle_layer->setReshapeDimensions(shape_dim); + engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f); + plugin_inputs.emplace_back( + shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3 + + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + layer = plugin_layer; + } + } else { + PADDLE_ENFORCE_EQ( + input->getDimensions().nbDims, 3, + platform::errors::InvalidArgument( + "The Input dim of the SparseMultiheadMatMul should be 3, " + "but it's (%d) now.", + input->getDimensions().nbDims)); + // transpose weight_data from m * n to n * m + auto* input_bias_qk = + engine_->GetITensor(op_desc.Input("BiasQK").front()); + + half* half_data = nullptr; + void* w_data = nullptr; + if (with_fp16) { + half_data = new half[weight_t->numel()]; + for (int i = 0; i < weight_t->numel(); i++) { + half_data[i] = static_cast(weight_data[i]); + } + w_data = static_cast(half_data); + } else { + w_data = static_cast(weight_data); + } + + TensorRTEngine::Weight weight{ + with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + static_cast(w_data), static_cast(weight_t->numel())}; + weight.dims.assign({n, m}); + + half* half_bias_data = nullptr; + void* b_data = nullptr; + if (with_fp16) { + half_bias_data = new half[bias_t->numel()]; + for (int i = 0; i < bias_t->numel(); i++) { + half_bias_data[i] = static_cast(bias_data[i]); + } + b_data = static_cast(half_bias_data); + } else { + b_data = static_cast(bias_data); + } + + TensorRTEngine::Weight bias{ + with_fp16 ? 
nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + b_data, static_cast(bias_t->numel())}; + + // add shuffle before fc + nvinfer1::Dims reshape_before_fc_dim; + reshape_before_fc_dim.nbDims = 5; + reshape_before_fc_dim.d[0] = 0; + reshape_before_fc_dim.d[1] = 0; + reshape_before_fc_dim.d[2] = 0; + reshape_before_fc_dim.d[3] = 1; + reshape_before_fc_dim.d[4] = 1; + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + if (op_desc.HasAttr("Input_scale")) { + engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0), + in_scale); + } + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + reshape_before_fc_layer->setName( + ("shuffle_before_sparse_multihead_mamul(Output: " + output_name + + ")") + .c_str()); + + // add layer fc + nvinfer1::ILayer* fc_layer = nullptr; + if (op_desc.HasAttr("Input_scale")) { + plugin::SpmmPluginDynamic* plugin = + new_spmm_plugin(&weight, &bias, nvinfer1::DataType::kINT8, n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(reshape_before_fc_layer->getOutput(0)); + fc_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + } else { + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + &weight, &bias, with_fp16 ? nvinfer1::DataType::kHALF + : nvinfer1::DataType::kFLOAT, + n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(reshape_before_fc_layer->getOutput(0)); + fc_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + } + + if (op_desc.HasAttr("fc_out_threshold")) { + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("fc_out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in multihead layers in int8 mode")); + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + } + fc_layer->setName( + ("sparse_multihead_mamul_fc(Output: " + output_name + ")").c_str()); + + // no need to add shuffle after fc, just change it in + // QkvToContextPluginDynamic + + // add qkv to context + int head_size = hidden_out / head_number; + float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha")); + + std::vector plugin_inputs; + plugin_inputs.push_back(fc_layer->getOutput(0)); + plugin_inputs.push_back(input_bias_qk); + + if (engine_->precision() == AnalysisConfig::Precision::kInt8) { + with_fp16 = true; + } + plugin::DynamicPluginTensorRT* plugin = + new plugin::QkvToContextPluginDynamic(hidden_in, head_number, + head_size, scale, with_fp16); + layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin); + } + } else { + PADDLE_THROW(platform::errors::Fatal( + "You are running the Ernie(Bert) model in static shape mode, which " + "is not supported for the time being.\n" + "You can use the config.SetTRTDynamicShapeInfo(...) 
interface to set " + "the shape information to run the dynamic shape mode.")); + } + RreplenishLayerAndOutput(layer, "multihead_matmul", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(sparse_multihead_matmul, + SparseMultiheadMatMulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 856985119b73e9..570d578a1c60d4 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -48,6 +48,8 @@ struct SimpleOpTypeSetTeller : public Teller { #if IS_TRT_VERSION_GE(8000) teller_set.insert("sparse_fc"); int8_teller_set.insert("sparse_fc"); + teller_set.insert("sparse_multihead_matmul"); + int8_teller_set.insert("sparse_multihead_matmul"); #endif } @@ -1756,9 +1758,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } #if IS_TRT_VERSION_GE(8000) - if (op_type == "sparse_fc") { + if (op_type == "sparse_fc" || op_type == "sparse_multihead_matmul") { if (!with_dynamic_shape) { - VLOG(3) << "the sparse_fc does not support static shape yet"; + VLOG(3) << "the sparse_fc and sparse_multihead_matmul does not support " + "static shape yet"; return false; } } diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu index b9cc7e55b7d2af..71ae18be22ba7c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu @@ -63,6 +63,12 @@ inline void deserialize_value_size(void const** buffer, size_t* buffer_size, inline float round_scale(float x) { return std::floor(x + 0.5f); } +inline void cudaFreeFunc(void* p) { + if (p) { + cudaFree(p); + } +} + inline void convertAndCopy(const nvinfer1::Weights& src, nvinfer1::DataType type, void* dest) { PADDLE_ENFORCE_EQ(src.type == nvinfer1::DataType::kFLOAT || @@ -252,6 +258,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, weight_scale_(1.0f), weight_compressed_(nullptr), weight_compressed_dev_(nullptr), + weight_compressed_dev_global_(nullptr), compressed_size_(0), has_bias_(false), bias_(nullptr), @@ -262,8 +269,10 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, 2. (Int8) Calculate scale and scale the weight (on host) 3. Copy weight to device 4. Compress the weight (on device) - 5. Copy the compressed weight to host - 6. Convert bias precision and copy (on host) + 5. Reset the shared_ptr "weight_compressed_dev_global_" to the compressed + weight + 6. Copy the compressed weight to host + 7. 
Convert bias precision and copy (on host) */ precision_size_ = getElementSize(precision); element_size_ = @@ -310,26 +319,26 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, cudaMemcpy(weight_dev, weight_host.data(), precision_size_ * weight.count, cudaMemcpyHostToDevice); } - spmm_context_.compressMatB(out_dim_, k_, convertTrtType(precision_), weight_dev, &weight_compressed_dev_, &compressed_size_); weight_compressed_ = new char[compressed_size_]; - cudaMemcpy(weight_compressed_, weight_compressed_dev_, compressed_size_, - cudaMemcpyDeviceToHost); - + weight_compressed_dev_global_.reset(weight_compressed_dev_, cudaFreeFunc); + cudaMemcpy(weight_compressed_, weight_compressed_dev_global_.get(), + compressed_size_, cudaMemcpyDeviceToHost); has_bias_ = (bias.count != 0); if (has_bias_) { if (bias.count != out_dim) { PADDLE_THROW(paddle::platform::errors::Fatal( "The dimension of bias should be equal to output dimension")); } - PADDLE_ENFORCE_EQ(bias.type, nvinfer1::DataType::kFLOAT, - platform::errors::InvalidArgument( - "SpmmPluginDynamic only supports FLOAT bias")); - - bias_ = new float[out_dim_]; - convertAndCopy(bias, nvinfer1::DataType::kFLOAT, bias_); + if (precision_ == nvinfer1::DataType::kHALF) { + bias_ = new half[out_dim_]; + convertAndCopy(bias, nvinfer1::DataType::kHALF, bias_); + } else { + bias_ = new float[out_dim_]; + convertAndCopy(bias, nvinfer1::DataType::kFLOAT, bias_); + } } cudaFree(weight_dev); @@ -352,7 +361,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, optim_alg_(optim_alg), weight_scale_(1.0f), weight_compressed_(nullptr), - weight_compressed_dev_(nullptr), + weight_compressed_dev_global_(nullptr), compressed_size_(compressed_size), has_bias_(false), bias_(nullptr), @@ -360,15 +369,15 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, activation_(activation) { /* 1. Copy the compressed weight (on host) - 2. Copy the compressed weight to device - 3. Copy the bias (on host) - 4. (Configured) Copy the bias to device - 5. (Configured) Init cuSPARSELt descriptors + 2. Copy the bias (on host) + 3. (Configured) Copy the bias to device + 4. (Configured) Init cuSPARSELt descriptors */ precision_size_ = getElementSize(precision); element_size_ = (precision_ == nvinfer1::DataType::kINT8 ? 
4 : precision_size_); - // Each plugin has a copy of compressed weight + // Each plugin has a copy of compressed weight on host, while sharing the + // compressed weights on device using std::shared_ptr weight_compressed_ = new char[compressed_size]; std::copy_n(static_cast(weight_compressed), compressed_size, static_cast(weight_compressed_)); @@ -399,6 +408,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, : layer_name_(name), weight_compressed_(nullptr), weight_compressed_dev_(nullptr), + weight_compressed_dev_global_(nullptr), bias_(nullptr), bias_dev_(nullptr) { DeserializeValue(&data, &length, &precision_); @@ -423,6 +433,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, compressed_size_); cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size_, cudaMemcpyHostToDevice); + weight_compressed_dev_global_.reset(weight_compressed_dev_, cudaFreeFunc); if (has_bias_) { bias_ = new float[out_dim_]; @@ -446,9 +457,8 @@ nvinfer1::IPluginV2DynamicExt* SpmmPluginDynamic::clone() const noexcept { weight_compressed_, compressed_size_, bias_, is_configured_, m_max_, optim_alg_, activation_); p->weight_scale_ = weight_scale_; + p->weight_compressed_dev_global_ = weight_compressed_dev_global_; p->setPluginNamespace(namespace_.c_str()); - p->weight_compressed_dev_ = weight_compressed_dev_; - return p; } catch (const std::exception& e) { std::cerr << e.what() << std::endl; @@ -548,24 +558,36 @@ void SpmmPluginDynamic::configurePlugin( platform::errors::InvalidArgument( "precision_ should be equal to inputs[0].desc.type")); const auto& inDims0 = inputs[0].desc.dims; - PADDLE_ENFORCE_EQ(inDims0.nbDims, 5, platform::errors::InvalidArgument( - "inDims0.nbDims should be 5")); - PADDLE_ENFORCE_EQ(k_, inDims0.d[2], - platform::errors::InvalidArgument( - "inDims0.d[2] should be equals to k")); - PADDLE_ENFORCE_EQ(inDims0.d[3], 1, platform::errors::InvalidArgument( - "inDims0.d[3] should be 1")); - PADDLE_ENFORCE_EQ(inDims0.d[4], 1, platform::errors::InvalidArgument( - "inDims0.d[4] should be 1")); - const int BS = inputs->max.d[0]; - - // The optimal algorighm id is for m = m_max_ - // To Do: configurePlugin takes time when m is changed + if (inDims0.nbDims == 5) { + PADDLE_ENFORCE_EQ(inDims0.nbDims, 5, platform::errors::InvalidArgument( + "inDims0.nbDims should be 5")); + PADDLE_ENFORCE_EQ(k_, inDims0.d[2], + platform::errors::InvalidArgument( + "inDims0.d[2] should be equals to k")); + PADDLE_ENFORCE_EQ(inDims0.d[3], 1, platform::errors::InvalidArgument( + "inDims0.d[3] should be 1")); + PADDLE_ENFORCE_EQ(inDims0.d[4], 1, platform::errors::InvalidArgument( + "inDims0.d[4] should be 1")); + const int BS = inputs->max.d[0]; + const int Seq = inputs->max.d[1]; + m_max_ = BS * Seq; + } else if (inDims0.nbDims == 4) { + PADDLE_ENFORCE_EQ(inDims0.nbDims, 4, platform::errors::InvalidArgument( + "inDims0.nbDims should be 4")); + PADDLE_ENFORCE_EQ(k_, inDims0.d[1], + platform::errors::InvalidArgument( + "inDims0.d[1] should be equals to k")); + PADDLE_ENFORCE_EQ(inDims0.d[2], 1, platform::errors::InvalidArgument( + "inDims0.d[2] should be 1")); + PADDLE_ENFORCE_EQ(inDims0.d[3], 1, platform::errors::InvalidArgument( + "inDims0.d[3] should be 1")); + const int BS_Seq = inputs->max.d[0]; + m_max_ = BS_Seq; + } if (is_configured_) { return; } - m_max_ = BS; if (has_bias_) { if (inputs->desc.type == nvinfer1::DataType::kINT8) { for (int i = 0; i < out_dim_; ++i) { @@ -596,7 +618,8 @@ void SpmmPluginDynamic::configurePlugin( 
spmm_context_.workspace_size); paddle::platform::dynload::cusparseLtMatmulSearch( &spmm_context_.handle, &spmm_context_.plan, &alpha, dA, - weight_compressed_dev_, &beta, dC, dC, d_workspace, nullptr, 0); + weight_compressed_dev_global_.get(), &beta, dC, dC, d_workspace, + nullptr, 0); paddle::platform::dynload::cusparseLtMatmulAlgGetAttribute( &spmm_context_.handle, &spmm_context_.alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &optim_alg_, sizeof(optim_alg_)); @@ -624,32 +647,44 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, PADDLE_ENFORCE_EQ(is_configured_, true, platform::errors::InvalidArgument( "The plugin is not configured before enqueue")); - PADDLE_ENFORCE_EQ( - k_, inputDesc->dims.d[2], - platform::errors::InvalidArgument("k_ == inputDesc->dims.d[2]")); + if (inputDesc->dims.nbDims == 5) { + PADDLE_ENFORCE_EQ( + k_, inputDesc->dims.d[2], + platform::errors::InvalidArgument("k_ == inputDesc->dims.d[2]")); + } else if (inputDesc->dims.nbDims == 4) { + PADDLE_ENFORCE_EQ( + k_, inputDesc->dims.d[1], + platform::errors::InvalidArgument("k_ == inputDesc->dims.d[1]")); + } float alpha = 1.0f; float beta = 0.0f; if (inputDesc->type == nvinfer1::DataType::kFLOAT) { const auto* const input = static_cast(inputs[0]); auto* output = static_cast(outputs[0]); + auto* weight_compressed_dev_p_ = weight_compressed_dev_global_.get(); cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( &spmm_context_.handle, &spmm_context_.plan, &alpha, input, - weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1); + weight_compressed_dev_p_, &beta, output, output, workSpace, &stream, + 1); return status != CUSPARSE_STATUS_SUCCESS; } else if (inputDesc->type == nvinfer1::DataType::kHALF) { const auto* const input = static_cast(inputs[0]); auto* output = static_cast(outputs[0]); + auto* weight_compressed_dev_p_ = weight_compressed_dev_global_.get(); cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( &spmm_context_.handle, &spmm_context_.plan, &alpha, input, - weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1); + weight_compressed_dev_p_, &beta, output, output, workSpace, &stream, + 1); return status != CUSPARSE_STATUS_SUCCESS; } else if (inputDesc->type == nvinfer1::DataType::kINT8) { alpha = inputDesc->scale * weight_scale_ / outputDesc->scale; const auto* const input = static_cast(inputs[0]); auto* output = static_cast(outputs[0]); + auto* weight_compressed_dev_p_ = weight_compressed_dev_global_.get(); cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( &spmm_context_.handle, &spmm_context_.plan, &alpha, input, - weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1); + weight_compressed_dev_p_, &beta, output, output, workSpace, &stream, + 1); return status != CUSPARSE_STATUS_SUCCESS; } else { PADDLE_THROW(paddle::platform::errors::Fatal( @@ -713,7 +748,6 @@ void SpmmPluginDynamic::serialize(void* buffer) const noexcept { SerializeValue(&buffer, compressed_size_); SerializeValue(&buffer, has_bias_); SerializeValue(&buffer, activation_); - char* d = static_cast(buffer); std::copy_n(static_cast(weight_compressed_), compressed_size_, d); @@ -725,10 +759,6 @@ void SpmmPluginDynamic::serialize(void* buffer) const noexcept { void SpmmPluginDynamic::destroy() noexcept { delete[] reinterpret_cast(weight_compressed_); - if (weight_compressed_dev_) { - cudaFree(weight_compressed_dev_); - weight_compressed_dev_ = nullptr; - } if (has_bias_) { cudaFree(bias_dev_); } diff --git 
a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h index a7edb8dedfa7fd..404fbff18b8c2e 100644 --- a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h @@ -39,6 +39,8 @@ #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/platform/dynload/cusparseLt.h" +using namespace std; + namespace paddle { namespace inference { namespace tensorrt { @@ -77,6 +79,7 @@ class SpmmPluginDynamic : public nvinfer1::IPluginV2DynamicExt { const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; @@ -128,7 +131,8 @@ class SpmmPluginDynamic : public nvinfer1::IPluginV2DynamicExt { int optim_alg_; // the index of optimal algorithm float weight_scale_; // record the weight scale from constructor void* weight_compressed_; // host compressed weight - void* weight_compressed_dev_; // device compressed weight + void* weight_compressed_dev_; // device compressed weight + shared_ptr weight_compressed_dev_global_; // shared pointer to the device compressed weight size_t compressed_size_; // size of compressed weight bool has_bias_; // there is bias or not void* bias_; // host bias diff --git a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc new file mode 100644 index 00000000000000..59407f486968b2 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc @@ -0,0 +1,170 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/common/float16.h"
+
+using float16 = phi::dtype::float16;
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class TensorRTDynamicEngineTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    ctx_ = new platform::CUDADeviceContext(platform::CUDAPlace(0));
+    ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
+                           .GetAllocator(platform::CUDAPlace(0), ctx_->stream())
+                           .get());
+    ctx_->SetHostAllocator(
+        paddle::memory::allocation::AllocatorFacade::Instance()
+            .GetAllocator(paddle::platform::CPUPlace())
+            .get());
+    ctx_->SetZeroAllocator(
+        paddle::memory::allocation::AllocatorFacade::Instance()
+            .GetZeroAllocator(platform::CUDAPlace(0))
+            .get());
+    ctx_->SetPinnedAllocator(
+        paddle::memory::allocation::AllocatorFacade::Instance()
+            .GetAllocator(paddle::platform::CUDAPinnedPlace())
+            .get());
+    ctx_->PartialInitWithAllocator();
+
+    std::map<std::string, std::vector<int>> min_input_shape = {
+        {"input", {16, 32, 1, 1}}};
+    std::map<std::string, std::vector<int>> max_input_shape = {
+        {"input", {16, 32, 1, 1}}};
+    std::map<std::string, std::vector<int>> optim_input_shape = {
+        {"input", {16, 32, 1, 1}}};
+
+    engine_ =
+        new TensorRTEngine(16, 1 << 10, AnalysisConfig::Precision::kHalf,
+                           nullptr, 0, min_input_shape, max_input_shape,
+                           optim_input_shape, false, NaiveLogger::Global());
+    engine_->InitNetwork();
+  }
+
+  void TearDown() override {
+    if (engine_) {
+      delete engine_;
+      engine_ = nullptr;
+    }
+  }
+
+  void PrepareInputOutput(const std::vector<float16> &input,
+                          std::vector<int> output_shape) {
+    paddle::framework::TensorFromVector(input, *ctx_, &input_);
+    output_.Resize(phi::make_ddim(output_shape));
+  }
+
+  void GetOutput(std::vector<float16> *output) {
+    paddle::framework::TensorToVector(output_, *ctx_, output);
+  }
+
+ protected:
+  framework::Tensor input_;
+  framework::Tensor output_;
+  TensorRTEngine *engine_;
+  platform::CUDADeviceContext *ctx_;
+};
+
+TEST_F(TensorRTDynamicEngineTest, test_spmm) {
+  // Weight in CPU memory.
+  float16 raw_weight[512];
+  for (int i = 0; i < 128; i++) {
+    if (i % 16 <= 7) {
+      raw_weight[4 * i] = float16(1.0);
+      raw_weight[4 * i + 1] = float16(0.0);
+      raw_weight[4 * i + 2] = float16(0.0);
+      raw_weight[4 * i + 3] = float16(4.0);
+    } else {
+      raw_weight[4 * i] = float16(0.0);
+      raw_weight[4 * i + 1] = float16(2.0);
+      raw_weight[4 * i + 2] = float16(3.0);
+      raw_weight[4 * i + 3] = float16(0.0);
+    }
+  }
+  float16 raw_bias[16] = {float16(0), float16(1), float16(0), float16(2),
+                          float16(0), float16(3), float16(0), float16(4),
+                          float16(0), float16(5), float16(0), float16(6),
+                          float16(0), float16(7), float16(0), float16(8)};
+  std::vector<void *> buffers(2);  // TRT binded inputs
+  TensorRTEngine::Weight weight(nvinfer1::DataType::kHALF, raw_weight, 512);
+  TensorRTEngine::Weight bias(nvinfer1::DataType::kHALF, raw_bias, 16);
+  std::cout << "with_dynamic_shape: " << engine_->with_dynamic_shape()
+            << std::endl;
+  auto *x = engine_->DeclareInput("input", nvinfer1::DataType::kHALF,
+                                  nvinfer1::Dims4{-1, 32, 1, 1});
+
+  plugin::SpmmPluginDynamic::Activation act =
+      plugin::SpmmPluginDynamic::Activation::kNone;
+
+  plugin::SpmmPluginDynamic *plugin = new plugin::SpmmPluginDynamic(
+      "CustomSpmmPluginDynamic", nvinfer1::DataType::kHALF, 16, weight.get(),
+      bias.get(), act);
+  std::vector<nvinfer1::ITensor *> plugin_inputs;
+  plugin_inputs.emplace_back(x);
+  auto fc_layer = engine_->network()->addPluginV2(
+      plugin_inputs.data(), plugin_inputs.size(), *plugin);
+
+  LOG(INFO) << "create weights";
+  PADDLE_ENFORCE_NOT_NULL(fc_layer, platform::errors::InvalidArgument(
+                                        "TRT SPMM layer building failed."));
+
+  engine_->DeclareOutput(fc_layer, 0, "y");
+  engine_->FreezeNetwork();
+  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);
+
+  std::vector<float16> x_v(512);
+  for (int i = 0; i < 128; i++) {
+    x_v[4 * i] = float16(1.0);
+    x_v[4 * i + 1] = float16(2.0);
+    x_v[4 * i + 2] = float16(3.0);
+    x_v[4 * i + 3] = float16(4.0);
+  }
+
+  std::vector<float16> y_cpu;
+  PrepareInputOutput(x_v, {16, 16});
+
+  auto *x_v_gpu_data = input_.mutable_data<float16>(ctx_->GetPlace());
+  auto *y_gpu_data = output_.mutable_data<float16>(ctx_->GetPlace());
+
+  buffers[0] = reinterpret_cast<void *>(x_v_gpu_data);
+  buffers[1] = reinterpret_cast<void *>(y_gpu_data);
+
+  engine_->Execute(16, &buffers, ctx_->stream());
+  LOG(INFO) << "to get output";
+  GetOutput(&y_cpu);
+
+  auto dims = engine_->GetITensor("y")->getDimensions();
+  ASSERT_EQ(dims.nbDims, 4);
+  ASSERT_EQ(dims.d[1], 16);
+
+  ASSERT_EQ(y_cpu[0], 136);
+  ASSERT_EQ(y_cpu[1], 105);
+  ASSERT_EQ(y_cpu[32], 136);
+  ASSERT_EQ(y_cpu[64], 136);
+  ASSERT_EQ(y_cpu[96], 136);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
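A closing note on the test data: `raw_weight` in the test above is built so that every aligned group of four consecutive values holds exactly two non-zeros (`{1, 0, 0, 4}` or `{0, 2, 3, 0}`), which is the 2:4 structured-sparsity layout that the cuSPARSELt compression inside `SpmmPluginDynamic` relies on. A small standalone checker for that property (illustrative only, not part of the patch) could look like this:

```cpp
#include <cstddef>

// Verifies 2:4 structured sparsity: in every aligned group of four values
// along the innermost dimension, at most two entries are non-zero.
// Assumes a dense row-major matrix whose column count is divisible by 4.
template <typename T>
bool IsTwoFourSparse(const T *data, std::size_t rows, std::size_t cols) {
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; c += 4) {
      int nonzeros = 0;
      for (std::size_t i = 0; i < 4; ++i) {
        if (static_cast<float>(data[r * cols + c + i]) != 0.0f) ++nonzeros;
      }
      if (nonzeros > 2) return false;
    }
  }
  return true;
}

// e.g. IsTwoFourSparse(raw_weight, 16, 32) holds for the test weight above
// (assuming the 16 x 32 row-major layout used by the plugin in this test).
```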