Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(//core/converters): Add normalization plugin #323

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/conversion/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ cc_library(
"impl/expand.cpp",
"impl/linear.cpp",
"impl/matrix_multiply.cpp",
"impl/normalize.cpp",
"impl/pooling.cpp",
"impl/reduce.cpp",
"impl/shuffle.cpp",
Expand Down
73 changes: 73 additions & 0 deletions core/conversion/converters/impl/normalize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include "NvInfer.h"
#include "NvInferRuntimeCommon.h"
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "plugins/normalize_plugin.h"
#include "torch/torch.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

/*
* Helper functions
*/
// Builds and attaches a NormalizePlugin layer for aten::norm.
// `order` is the p of the p-norm, `axes` the dimensions to reduce over
// (negative values are wrapped to positive indices), `keep_dims` keeps
// reduced dims as size 1. Throws if any wrapped axis exceeds the input rank.
void create_plugin(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    nvinfer1::ITensor* in,
    int64_t order,
    std::vector<int64_t> axes,
    bool keep_dims,
    const char* name) {
  LOG_WARNING("Normalize layer will be run through ATen, not TensorRT. Performance may be lower than expected");

  // Stack-allocate the creator: the original `new` was never deleted (leak).
  plugins::NormalizePluginCreator creator;
  auto inputnbDims = in->getDimensions().nbDims;
  for (size_t i = 0; i < axes.size(); i++) {
    // Wrap negative axes (PyTorch convention) to positive indices.
    if (axes[i] < 0) {
      axes[i] += inputnbDims;
    }
    if (axes[i] > inputnbDims - 1) {
      TRTORCH_THROW_ERROR("Axis of normalization layer cannot exceed input rank");
    }
  }

  auto plugin = creator.createPlugin(name, order, axes, keep_dims);

  auto normalize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
  TRTORCH_CHECK(normalize_layer, "Unable to create normalization plugin from node" << *n);

  normalize_layer->setName(util::node_info(n).c_str());

  auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], normalize_layer->getOutput(0));

  LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
}

// Registers a converter pattern mapping aten::norm.ScalarOpt_dim onto the
// Normalize plugin during conversion.
auto normalize_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       // Unpack schema arguments: input tensor, norm order p, reduction dims,
       // and the keepdim flag.
       auto in = args[0].ITensor();
       auto in_shape = util::toVec(in->getDimensions());
       // NOTE(review): unwrapToScalar() assumes p is not None here — confirm
       // the schema's `Scalar? p` can never reach this converter as None.
       auto order = args[1].unwrapToScalar().to<int64_t>();
       auto axes = args[2].unwrapToIntList().vec();
       auto keep_dims = args[3].unwrapToBool();
       LOG_DEBUG("Order of normalize_plugin: " << order);
       LOG_DEBUG("Axis: " << axes);
       LOG_DEBUG("keep_dims: " << keep_dims);
       create_plugin(ctx, n, in, order, axes, keep_dims, "Normalize");
       return true;
     }

    });

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
9 changes: 6 additions & 3 deletions core/conversion/converters/impl/plugins/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ config_setting(
cc_library(
name = "plugins",
hdrs = [
"interpolate_plugin.h"
"interpolate_plugin.h",
"normalize_plugin.h"
],
srcs = [
"interpolate_plugin.cpp"
"interpolate_plugin.cpp",
"normalize_plugin.cpp"
],
deps = [
"@tensorrt//:nvinfer",
Expand All @@ -37,5 +39,6 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
pkg_tar(
name = "include",
package_dir = "core/conversion/converters/impl/plugins",
srcs = ["interpolate_plugin.h"],
srcs = ["interpolate_plugin.h",
"normalize_plugin.h"],
)
277 changes: 277 additions & 0 deletions core/conversion/converters/impl/plugins/normalize_plugin.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
#include "normalize_plugin.h"

using namespace nvinfer1;

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace plugins {

/*
* NormalizePlugin class implementations
*/

// Constructs the plugin from the converter-supplied norm configuration:
// order p, axes to reduce over, and whether reduced dims are kept as size 1.
NormalizePlugin::NormalizePlugin(int64_t order, std::vector<int64_t> axes, bool keep_dims)
    : order_(order), axes_(axes), keep_dims_(keep_dims) {}

// Deserializing constructor: rebuilds plugin state from a torch serialization
// archive previously produced by serializeToString().
NormalizePlugin::NormalizePlugin(const char* data, size_t length) {
  std::istringstream data_stream(std::string(data, length));

  torch::serialize::InputArchive input_archive;
  input_archive.load_from(data_stream);

  torch::IValue order_val;
  torch::IValue axes_val;
  torch::IValue keep_dims_val;
  input_archive.read("order", order_val);
  input_archive.read("axes", axes_val);
  input_archive.read("keep_dims", keep_dims_val);

  order_ = order_val.toInt();
  axes_ = axes_val.toIntVector();
  keep_dims_ = keep_dims_val.toBool();
}

// The plugin produces a single output tensor (the norm result).
int NormalizePlugin::getNbOutputs() const {
  return 1;
}

// Plugin type name; must match NormalizePluginCreator::getPluginName() for
// the TensorRT registry to pair them.
const char* NormalizePlugin::getPluginType() const {
  return "Normalize";
}

// Plugin version; must match NormalizePluginCreator::getPluginVersion().
const char* NormalizePlugin::getPluginVersion() const {
  return "1";
}

// Plugin lives in the default (empty) namespace.
const char* NormalizePlugin::getPluginNamespace() const {
  return "";
}

// Copies the plugin configuration (order, axes, keep_dims). The clone's
// tensor_options_ is re-established when TensorRT calls initialize() on it.
nvinfer1::IPluginV2DynamicExt* NormalizePlugin::clone() const {
  return new NormalizePlugin(order_, axes_, keep_dims_);
}

// Computes the output shape of the norm: reduced axes are dropped, or kept as
// size-1 dims when keep_dims_ is set; all other axes pass through unchanged.
nvinfer1::DimsExprs NormalizePlugin::getOutputDimensions(
    int outputIndex,
    const nvinfer1::DimsExprs* inputs,
    int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) {
  nvinfer1::DimsExprs output;
  output.nbDims = keep_dims_ ? inputs[0].nbDims : inputs[0].nbDims - axes_.size();

  // For order-0 norm, when the norm dimension is None, it should normalize across all dimensions.
  // TODO: For dim=None, the axes_ passed would have [0, 0, 0] which is obtained through loop counter in TRTorch.
  // Resolve this. For dim=None case, change the axes_ inplace to range(0, axes_.size())
  bool isAxisNone = std::all_of(axes_.begin(), axes_.end(), [](int64_t i) { return i == 0; }) &&
      (static_cast<int64_t>(axes_.size()) == inputs[0].nbDims);
  if (isAxisNone) {
    std::iota(axes_.begin(), axes_.end(), 0);
  }

  int64_t out_idx = 0;
  for (int64_t i = 0; i < inputs[0].nbDims; i++) {
    if (std::find(axes_.begin(), axes_.end(), i) != axes_.end()) {
      // Reduced axis: only present in the output (as size 1) when keeping dims.
      if (keep_dims_) {
        output.d[out_idx] = exprBuilder.constant(1);
        out_idx += 1;
      }
    } else {
      // Non-reduced axis: forward the input dimension expression directly.
      // BUGFIX: the previous constant(inputs[0].d[i]->getConstantValue())
      // only works for build-time-constant dims and breaks dynamic shapes.
      output.d[out_idx] = inputs[0].d[i];
      out_idx += 1;
    }
  }

  return output;
}

// The plugin always computes in (and emits) FP32, matching the FP32-only
// constraint enforced in supportsFormatCombination.
nvinfer1::DataType NormalizePlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs)
    const {
  return DataType::kFLOAT;
}

// Picks where the ATen fallback runs: CUDA tensors on TRT < 7.1, host (CPU)
// tensors otherwise — see the staging workaround in enqueue().
int NormalizePlugin::initialize() {
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
  tensor_options_ = tensor_options_.device(c10::kCUDA);
#else
  tensor_options_ = tensor_options_.device(c10::kCPU);
#endif

  // c10::kFloat = FLOAT32
  tensor_options_ = tensor_options_.dtype(c10::kFloat);

  return 0;
}

// Writes the serialized plugin state into `buffer` (which TensorRT sizes via
// getSerializationSize()).
void NormalizePlugin::serialize(void* buffer) const {
  std::string data = serializeToString();
  // Use data.size() directly: calling getSerializationSize() here re-ran the
  // entire torch serialization a second time for no benefit.
  data.copy(static_cast<char*>(buffer), data.size());
}

// Serializes order/axes/keep_dims through a torch OutputArchive into a string;
// the deserializing constructor reads the same keys back.
std::string NormalizePlugin::serializeToString() const {
  torch::serialize::OutputArchive output_archive;

  output_archive.write("order", torch::IValue(order_));
  output_archive.write("axes", torch::IValue(axes_));
  output_archive.write("keep_dims", torch::IValue(keep_dims_));
  std::ostringstream data_str;
  output_archive.save_to(data_str);

  return data_str.str();
}

// Size of the serialized blob. Note this performs a full serialization just
// to measure it — acceptable because it only runs at engine build/save time.
size_t NormalizePlugin::getSerializationSize() const {
  return serializeToString().size();
}

// Accepts exactly one input and one output. The input must be linear-format
// FP32; the output must match the input's type and format.
bool NormalizePlugin::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc* inOut,
    int nbInputs,
    int nbOutputs) {
  TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output");
  TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to normalize plugin");
  TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to normalize plugin");

  const PluginTensorDesc& input_desc = inOut[0];

  if (pos == 0) {
    // Input tensor: linear-format FP32 only.
    bool is_float = input_desc.type == nvinfer1::DataType::kFLOAT;
    bool is_linear = input_desc.format == nvinfer1::TensorFormat::kLINEAR;
    return is_float && is_linear;
  }

  // pos == 1: the output must mirror the input's type and format.
  const PluginTensorDesc& output_desc = inOut[1];
  return (output_desc.type == input_desc.type) && (output_desc.format == input_desc.format);
}

// Called by TensorRT before enqueue; the plugin only records that it will
// operate on FP32 data (the in/out descriptors are otherwise unused).
void NormalizePlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in,
    int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out,
    int nbOutputs) {
  dtype_ = DataType::kFLOAT;
}

// No TensorRT-managed scratch space is needed: enqueue allocates its own
// host staging buffer (or uses ATen allocations directly).
size_t NormalizePlugin::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs,
    int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs,
    int nbOutputs) const {
  return 0;
}

// Runs the norm through ATen. On TRT < 7.1 the computation stays on the GPU
// with event-based handoff between the TensorRT stream and a PyTorch stream;
// on newer TRT the input is staged to the host, reduced on CPU, and copied
// back (see the HACK note below).
int NormalizePlugin::enqueue(
    const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs,
    void* const* outputs,
    void* workspace,
    cudaStream_t stream) {
// TRT <= 7.0
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
  at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_);
  at::Tensor output = at::from_blob(outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_);

  at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
  at::cuda::CUDAStreamGuard torch_guard(torch_stream);

  // Make the PyTorch stream wait for all work already queued on the TRT stream.
  cudaEvent_t event;
  cudaEventCreate(&event);
  cudaEventRecord(event, stream);
  cudaStreamWaitEvent(torch_stream.stream(), event, 0);

  // BUGFIX: previously referenced undeclared `axes`; the member is `axes_`.
  at::Tensor result = at::norm(input, order_, axes_, keep_dims_);
  output.copy_(result);

  // Make the TRT stream wait for the PyTorch computation before continuing.
  cudaEvent_t torch_event;
  cudaEventCreate(&torch_event);
  cudaEventRecord(torch_event, torch_stream.stream());
  cudaStreamWaitEvent(stream, torch_event, 0);

  cudaEventDestroy(event);
  cudaEventDestroy(torch_event);
  return 0;
#else
  // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen
  // kernels HACK: WAR because there is a segfault if you try to create a CUDA
  // Tensor in the context of TensorRT execution
  // std::vector replaces the raw malloc/free: RAII cleanup on every path and
  // no silently-ignored allocation failure.
  std::vector<float> input_blob(util::volume(inputDesc->dims));
  cudaMemcpyAsync(
      input_blob.data(),
      static_cast<const void*>(inputs[0]),
      input_blob.size() * sizeof(float),
      cudaMemcpyDeviceToHost,
      stream);
  cudaStreamSynchronize(stream);

  at::Tensor input = at::from_blob((void*)input_blob.data(), util::toVec(inputDesc->dims), tensor_options_);
  at::Tensor output = at::norm(input, order_, axes_, keep_dims_);

  cudaMemcpyAsync(
      outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);

  return 0;
#endif
}

/*
* NormalizePluginCreator class implementations
*/
// Creator lives in the default (empty) namespace, matching the plugin.
const char* NormalizePluginCreator::getPluginNamespace() const {
  return "";
}

// Must match NormalizePlugin::getPluginType() so the TensorRT registry can
// route deserialization to this creator.
const char* NormalizePluginCreator::getPluginName() const {
  return "Normalize";
}

// Must match NormalizePlugin::getPluginVersion().
const char* NormalizePluginCreator::getPluginVersion() const {
  return "1";
}

// The standard field-collection creation path is unsupported; plugins are
// built via the typed createPlugin overload below. nullptr signals this.
nvinfer1::IPluginV2* NormalizePluginCreator::createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) {
  return nullptr;
}

// Converter-facing factory (bypasses the PluginFieldCollection mechanism).
// The caller takes ownership of the returned plugin.
NormalizePlugin* NormalizePluginCreator::createPlugin(
    const char* name,
    int64_t order,
    std::vector<int64_t> axes,
    bool keep_dims) {
  name_ = name;
  return new NormalizePlugin(order, axes, keep_dims);
}

// Reconstructs a plugin from a serialized engine blob (produced by
// NormalizePlugin::serialize). The caller takes ownership of the result.
nvinfer1::IPluginV2* NormalizePluginCreator::deserializePlugin(
    const char* name,
    const void* serialData,
    size_t serialLength) {
  name_ = name;
  return new NormalizePlugin((const char*)serialData, serialLength);
}

// No plugin fields are exposed; consistent with the unsupported
// field-collection createPlugin path above.
const nvinfer1::PluginFieldCollection* NormalizePluginCreator::getFieldNames() {
  return nullptr;
}

REGISTER_TENSORRT_PLUGIN(NormalizePluginCreator);

} // namespace plugins
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
Loading