From be6490079ac2c592d24134732648233a7bf96399 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 2 Feb 2021 13:44:39 -0800 Subject: [PATCH] Add normalization plugin --- core/conversion/converters/BUILD | 1 + core/conversion/converters/impl/normalize.cpp | 73 +++++ core/conversion/converters/impl/plugins/BUILD | 9 +- .../impl/plugins/normalize_plugin.cpp | 277 ++++++++++++++++++ .../impl/plugins/normalize_plugin.h | 138 +++++++++ tests/core/conversion/converters/BUILD | 5 + .../converters/test_element_wise.cpp | 104 +++---- .../conversion/converters/test_normalize.cpp | 77 +++++ 8 files changed, 629 insertions(+), 55 deletions(-) create mode 100644 core/conversion/converters/impl/normalize.cpp create mode 100644 core/conversion/converters/impl/plugins/normalize_plugin.cpp create mode 100644 core/conversion/converters/impl/plugins/normalize_plugin.h create mode 100644 tests/core/conversion/converters/test_normalize.cpp diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD index 11dd04abde..a3d0dcb0b6 100755 --- a/core/conversion/converters/BUILD +++ b/core/conversion/converters/BUILD @@ -42,6 +42,7 @@ cc_library( "impl/expand.cpp", "impl/linear.cpp", "impl/matrix_multiply.cpp", + "impl/normalize.cpp", "impl/pooling.cpp", "impl/reduce.cpp", "impl/shuffle.cpp", diff --git a/core/conversion/converters/impl/normalize.cpp b/core/conversion/converters/impl/normalize.cpp new file mode 100644 index 0000000000..d4f03494bb --- /dev/null +++ b/core/conversion/converters/impl/normalize.cpp @@ -0,0 +1,73 @@ +#include "NvInfer.h" +#include "NvInferRuntimeCommon.h" +#include "core/conversion/converters/converters.h" +#include "core/util/prelude.h" +#include "plugins/normalize_plugin.h" +#include "torch/torch.h" + +namespace trtorch { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { + +/* + * Helper functions + */ +void create_plugin( + ConversionCtx* ctx, + const torch::jit::Node* n, + nvinfer1::ITensor* in, + int64_t order, + std::vector axes, + bool keep_dims, + const char* name) { + LOG_WARNING("Normalize layer will be run through ATen, not TensorRT. Performance may be lower than expected"); + + auto creator = new plugins::NormalizePluginCreator(); + auto inputnbDims = in->getDimensions().nbDims; + for (int64_t i = 0; i < axes.size(); i++) { + if (axes[i] < 0) { + axes[i] += inputnbDims; + } + if (axes[i] > inputnbDims - 1) { + TRTORCH_THROW_ERROR("Axis of normalization layer cannot exceed input rank"); + } + } + + auto plugin = creator->createPlugin(name, order, axes, keep_dims); + + auto normalize_layer = ctx->net->addPluginV2(reinterpret_cast(&in), 1, *plugin); + TRTORCH_CHECK(normalize_layer, "Unable to create normalization plugin from node" << *n); + + normalize_layer->setName(util::node_info(n).c_str()); + + auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], normalize_layer->getOutput(0)); + + LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions()); +} + +auto normalize_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern( + {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + auto order = args[1].unwrapToScalar().to(); + auto axes = args[2].unwrapToIntList().vec(); + auto keep_dims = args[3].unwrapToBool(); + LOG_DEBUG("Order of normalize_plugin: " << order); + LOG_DEBUG("Axis: " << axes); + LOG_DEBUG("keep_dims: " << keep_dims); + create_plugin(ctx, n, in, order, axes, keep_dims, "Normalize"); + return true; + } + + }); + +} // namespace +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace trtorch diff --git a/core/conversion/converters/impl/plugins/BUILD b/core/conversion/converters/impl/plugins/BUILD index b89685ac6e..e0fc1f59d0 100755 --- a/core/conversion/converters/impl/plugins/BUILD +++ b/core/conversion/converters/impl/plugins/BUILD @@ -10,10 +10,12 @@ config_setting( cc_library( name = "plugins", hdrs = [ - "interpolate_plugin.h" + "interpolate_plugin.h", + "normalize_plugin.h" ], srcs = [ - "interpolate_plugin.cpp" + "interpolate_plugin.cpp", + "normalize_plugin.cpp" ], deps = [ "@tensorrt//:nvinfer", @@ -37,5 +39,6 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar") pkg_tar( name = "include", package_dir = "core/conversion/converters/impl/plugins", - srcs = ["interpolate_plugin.h"], + srcs = ["interpolate_plugin.h", + "normalize_plugin.h"], ) diff --git a/core/conversion/converters/impl/plugins/normalize_plugin.cpp b/core/conversion/converters/impl/plugins/normalize_plugin.cpp new file mode 100644 index 0000000000..b03018e67b --- /dev/null +++ b/core/conversion/converters/impl/plugins/normalize_plugin.cpp @@ -0,0 +1,277 @@ +#include "normalize_plugin.h" + +using namespace nvinfer1; + +namespace trtorch { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace plugins { + +/* + * NormalizePlugin class implementations + */ + +NormalizePlugin::NormalizePlugin(int64_t order, std::vector axes, bool keep_dims) + : order_(order), axes_(axes), keep_dims_(keep_dims) {} + +NormalizePlugin::NormalizePlugin(const char* data, size_t length) { + std::istringstream data_stream(std::string(data, length)); + + torch::serialize::InputArchive input_archive; + input_archive.load_from(data_stream); + { + torch::IValue value; + input_archive.read("order", value); + order_ = value.toInt(); + } + { + torch::IValue value; + input_archive.read("axes", value); + axes_ = value.toIntVector(); + } + { + torch::IValue value; + input_archive.read("keep_dims", value); + keep_dims_ = value.toBool(); + } +} + +int NormalizePlugin::getNbOutputs() const { + return 1; +} + +const char* NormalizePlugin::getPluginType() const { + return "Normalize"; +} + +const char* NormalizePlugin::getPluginVersion() const { + return "1"; +} + +const char* NormalizePlugin::getPluginNamespace() const { + return ""; +} + +nvinfer1::IPluginV2DynamicExt* NormalizePlugin::clone() const { + return new NormalizePlugin(order_, axes_, keep_dims_); +} + +nvinfer1::DimsExprs NormalizePlugin::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) { + nvinfer1::DimsExprs output; + output.nbDims = keep_dims_ ? inputs[0].nbDims : inputs[0].nbDims - axes_.size(); + + // For order-0 norm, when the norm dimension is None, it should normalize across all dimensions. + // TODO: For dim=None, the axes_ passed would have [0, 0, 0] which is obtained through loop counter in TRTorch. + // Resolve this. For dim=None case, change the axes_ inplace to range(0, axes_.size()) + bool isAxisNone = + std::all_of(axes_.begin(), axes_.end(), [](int64_t i) { return i == 0; }) && (axes_.size() == inputs[0].nbDims); + if (isAxisNone) { + std::iota(axes_.data(), axes_.data() + axes_.size(), 0); + } + int64_t out_idx = 0; + for (int64_t i = 0; i < inputs[0].nbDims; i++) { + if (std::find(axes_.begin(), axes_.end(), i) != axes_.end()) { + if (keep_dims_) { + output.d[out_idx] = exprBuilder.constant(1); + out_idx += 1; + } + } else { + if (!isAxisNone) { + output.d[out_idx] = exprBuilder.constant(inputs[0].d[i]->getConstantValue()); + } else { + output.d[out_idx] = exprBuilder.constant(1); + } + out_idx += 1; + } + } + + return output; +} + +nvinfer1::DataType NormalizePlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) + const { + return DataType::kFLOAT; +} + +int NormalizePlugin::initialize() { +#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + tensor_options_ = tensor_options_.device(c10::kCUDA); +#else + tensor_options_ = tensor_options_.device(c10::kCPU); +#endif + + // c10::kFloat = FLOAT32 + tensor_options_ = tensor_options_.dtype(c10::kFloat); + + return 0; +} + +void NormalizePlugin::serialize(void* buffer) const { + std::string data = serializeToString(); + size_t size = getSerializationSize(); + data.copy((char*)buffer, size); +} + +std::string NormalizePlugin::serializeToString() const { + torch::serialize::OutputArchive output_archive; + + output_archive.write("order", torch::IValue(order_)); + output_archive.write("axes", torch::IValue(axes_)); + output_archive.write("keep_dims", torch::IValue(keep_dims_)); + std::ostringstream data_str; + output_archive.save_to(data_str); + + return data_str.str(); +} + +size_t NormalizePlugin::getSerializationSize() const { + return serializeToString().size(); +} + +bool NormalizePlugin::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) { + TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output"); + TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to normalize plugin"); + TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to normalize plugin"); + + const PluginTensorDesc& in = inOut[0]; + + if (pos == 0) { + return (in.type == nvinfer1::DataType::kFLOAT) && (in.format == nvinfer1::TensorFormat::kLINEAR); + } + + // pos == 1, accessing information about output tensor + const PluginTensorDesc& out = inOut[1]; + + return (in.type == out.type) && (in.format == out.format); +} + +void NormalizePlugin::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) { + dtype_ = DataType::kFLOAT; +} + +size_t NormalizePlugin::getWorkspaceSize( + const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const { + return 0; +} + +int NormalizePlugin::enqueue( + const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) { + // TRT <= 7.0 +#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_); + at::Tensor output = at::from_blob(outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_); + + at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool(); + at::cuda::CUDAStreamGuard torch_guard(torch_stream); + + cudaEvent_t event; + cudaEventCreate(&event); + cudaEventRecord(event, stream); + + cudaStreamWaitEvent(torch_stream.stream(), event, 0); + at::Tensor result = at::norm(input, order_, axes, keep_dims_); + output.copy_(result); + cudaEvent_t torch_event; + cudaEventCreate(&torch_event); + cudaEventRecord(torch_event, torch_stream.stream()); + + cudaStreamWaitEvent(stream, torch_event, 0); + + cudaEventDestroy(event); + cudaEventDestroy(torch_event); + return 0; +#else + // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen + // kernels HACK: WAR because there is a segfault if you try to create a CUDA + // Tensor in the context of TensorRT execution + float* input_blob = (float*)malloc(util::volume(inputDesc->dims) * sizeof(float)); + cudaMemcpyAsync( + input_blob, + static_cast(inputs[0]), + util::volume(inputDesc->dims) * sizeof(float), + cudaMemcpyDeviceToHost, + stream); + cudaStreamSynchronize(stream); + + at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_); + at::Tensor output = at::norm(input, order_, axes_, keep_dims_); + + cudaMemcpyAsync( + outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream); + cudaStreamSynchronize(stream); + + free(input_blob); + return 0; +#endif +} + +/* + * NormalizePluginCreator class implementations + */ +const char* NormalizePluginCreator::getPluginNamespace() const { + return ""; +} + +const char* NormalizePluginCreator::getPluginName() const { + return "Normalize"; +} + +const char* NormalizePluginCreator::getPluginVersion() const { + return "1"; +} + +nvinfer1::IPluginV2* NormalizePluginCreator::createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) { + return nullptr; +} + +NormalizePlugin* NormalizePluginCreator::createPlugin( + const char* name, + int64_t order, + std::vector axes, + bool keep_dims) { + name_ = name; + return new NormalizePlugin(order, axes, keep_dims); +} + +nvinfer1::IPluginV2* NormalizePluginCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) { + name_ = name; + return new NormalizePlugin((const char*)serialData, serialLength); +} + +const nvinfer1::PluginFieldCollection* NormalizePluginCreator::getFieldNames() { + return nullptr; +} + +REGISTER_TENSORRT_PLUGIN(NormalizePluginCreator); + +} // namespace plugins +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace trtorch diff --git a/core/conversion/converters/impl/plugins/normalize_plugin.h b/core/conversion/converters/impl/plugins/normalize_plugin.h new file mode 100644 index 0000000000..35a06b352c --- /dev/null +++ b/core/conversion/converters/impl/plugins/normalize_plugin.h @@ -0,0 +1,138 @@ +#ifndef TRTORCH_NORMALIZE_PLUGIN_H +#define TRTORCH_NORMALIZE_PLUGIN_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "NvInfer.h" +#include "core/util/prelude.h" +#include "torch/torch.h" + +using namespace nvinfer1; + +namespace trtorch { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace plugins { + +class NormalizePlugin : public nvinfer1::IPluginV2DynamicExt { + private: + at::TensorOptions tensor_options_; + DataType dtype_; + int64_t order_; + std::vector axes_; + bool keep_dims_; + + protected: + // To prevent compiler warnings + using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::configurePlugin; + using nvinfer1::IPluginV2DynamicExt::enqueue; + using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; + using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize; + using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::supportsFormat; + + public: + NormalizePlugin(int64_t order, std::vector axes, bool keep_dims); + + NormalizePlugin(const char* data, size_t length); + + NormalizePlugin() = delete; + + int getNbOutputs() const override; + + const char* getPluginType() const override; + + const char* getPluginVersion() const override; + + const char* getPluginNamespace() const override; + + void setPluginNamespace(const char* pluginNamespace) override{}; + + nvinfer1::IPluginV2DynamicExt* clone() const override; + + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) override; + + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override; + + int initialize() override; + + void terminate() override {} + + void serialize(void* buffer) const; + + std::string serializeToString() const; + + size_t getSerializationSize() const override; + + void destroy() override {} + + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) + override; + + void configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) override; + + size_t getWorkspaceSize( + const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const override; + + int enqueue( + const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) override; +}; + +class NormalizePluginCreator : public nvinfer1::IPluginCreator { + private: + std::string name_; + + public: + NormalizePluginCreator() = default; + + const char* getPluginNamespace() const override; + + void setPluginNamespace(const char* libNamespace) override{}; + + const char* getPluginName() const override; + + const char* getPluginVersion() const override; + + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override; + + NormalizePlugin* createPlugin(const char* name, int64_t order, std::vector axes, bool keep_dims); + + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override; + + const nvinfer1::PluginFieldCollection* getFieldNames() override; +}; + +} // namespace plugins +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace trtorch + +#endif // TRTORCH_NORMALIZE_PLUGIN_H diff --git a/tests/core/conversion/converters/BUILD b/tests/core/conversion/converters/BUILD index 4eba4f4c9f..039c525fcf 100644 --- a/tests/core/conversion/converters/BUILD +++ b/tests/core/conversion/converters/BUILD @@ -63,6 +63,10 @@ converter_test( name = "test_interpolate" ) +converter_test( + name = "test_normalize" +) + converter_test( name = "test_select" ) @@ -103,6 +107,7 @@ test_suite( ":test_softmax", ":test_unary", ":test_interpolate", + ":test_normalize", ":test_select", ":test_stack", ":test_lstm_cell", diff --git a/tests/core/conversion/converters/test_element_wise.cpp b/tests/core/conversion/converters/test_element_wise.cpp index 4b2fe9cc2d..5431e9aa2a 100644 --- a/tests/core/conversion/converters/test_element_wise.cpp +++ b/tests/core/conversion/converters/test_element_wise.cpp @@ -172,117 +172,117 @@ TEST(Converters, ATenNeScalarConvertsCorrectly) { pointwise_test_helper(graph, true, false, {3, 4, 2}); ; -TEST(Converters, ATenClampMinConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenClampMinConvertsCorrectly) { + const auto graph = R"IR( graph(%x.1 : Tensor): %2 : int = prim::Constant[value=-2]() %3 : None = prim::Constant() %4 : Tensor = aten::clamp(%x.1, %2, %3) return (%4))IR"; - pointwise_test_helper(graph, true); -} + pointwise_test_helper(graph, true); + } -TEST(Converters, ATenClampMaxConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenClampMaxConvertsCorrectly) { + const auto graph = R"IR( graph(%x.1 : Tensor): %2 : int = prim::Constant[value=3]() %3 : None = prim::Constant() %4 : Tensor = aten::clamp(%x.1, %3, %2) return (%4))IR"; - pointwise_test_helper(graph, true); -} + pointwise_test_helper(graph, true); + } -TEST(Converters, ATenClampMinMaxConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenClampMinMaxConvertsCorrectly) { + const auto graph = R"IR( graph(%x.1 : Tensor): %2 : int = prim::Constant[value=3]() %3 : int = prim::Constant[value=-2]() %4 : Tensor = aten::clamp(%x.1, %3, %2) return (%4))IR"; - pointwise_test_helper(graph, true); -} + pointwise_test_helper(graph, true); + } -TEST(Converters, ATenGreaterThanConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenGreaterThanConvertsCorrectly) { + const auto graph = R"IR( graph(%0 : Tensor, %1 : Tensor): %2 : Tensor = aten::gt(%0, %1) return (%2))IR"; - pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}); -} + pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}); + } -TEST(Converters, ATenGreaterThanScalarConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenGreaterThanScalarConvertsCorrectly) { + const auto graph = R"IR( graph(%0 : Tensor): %scalar : float = prim::Constant[value=3]() %2 : Tensor = aten::gt(%0, %scalar) return (%2))IR"; - pointwise_test_helper(graph, true, false, {5, 5}); -} + pointwise_test_helper(graph, true, false, {5, 5}); + } -TEST(Converters, ATenLessThanConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenLessThanConvertsCorrectly) { + const auto graph = R"IR( graph(%0 : Tensor, %1 : Tensor): %2 : Tensor = aten::lt(%0, %1) return (%2))IR"; - pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}); -} + pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}); + } -TEST(Converters, ATenLessThanScalarConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenLessThanScalarConvertsCorrectly) { + const auto graph = R"IR( graph(%0 : Tensor): %scalar : float = prim::Constant[value=3]() %2 : Tensor = aten::lt(%0, %scalar) return (%2))IR"; - pointwise_test_helper(graph, true, false, {5, 5}); -} + pointwise_test_helper(graph, true, false, {5, 5}); + } -TEST(Converters, ATenEqualConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenEqualConvertsCorrectly) { + const auto graph = R"IR( graph(%0 : Tensor, %1 : Tensor): %2 : Tensor = aten::eq(%0, %1) return (%2))IR"; - pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}); -} + pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}); + } -TEST(Converters, ATenEqualScalarConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenEqualScalarConvertsCorrectly) { + const auto graph = R"IR( graph(%0 : Tensor): %scalar : float = prim::Constant[value=3]() %2 : Tensor = aten::eq(%0, %scalar) return (%2))IR"; - pointwise_test_helper(graph, true, false, {5, 5}); -} + pointwise_test_helper(graph, true, false, {5, 5}); + } -TEST(Converters, ATenGEConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenGEConvertsCorrectly) { + const auto graph = R"IR( graph(%0 : Tensor, %1 : Tensor): %2 : Tensor = aten::ge(%0, %1) return (%2))IR"; - pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}); -} + pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}); + } -TEST(Converters, ATenGEScalarConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenGEScalarConvertsCorrectly) { + const auto graph = R"IR( graph(%0 : Tensor): %scalar : float = prim::Constant[value=3]() %2 : Tensor = aten::ge(%0, %scalar) return (%2))IR"; - pointwise_test_helper(graph, true, false, {5, 5}); -} + pointwise_test_helper(graph, true, false, {5, 5}); + } -TEST(Converters, ATenLEConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenLEConvertsCorrectly) { + const auto graph = R"IR( graph(%0 : Tensor, %1 : Tensor): %2 : Tensor = aten::le(%0, %1) return (%2))IR"; - pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}); -} + pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}); + } -TEST(Converters, ATenLEScalarConvertsCorrectly) { - const auto graph = R"IR( + TEST(Converters, ATenLEScalarConvertsCorrectly) { + const auto graph = R"IR( graph(%0 : Tensor): %scalar : float = prim::Constant[value=3]() %2 : Tensor = aten::le(%0, %scalar) return (%2))IR"; - pointwise_test_helper(graph, true, false, {5, 5}); -} + pointwise_test_helper(graph, true, false, {5, 5}); + } diff --git a/tests/core/conversion/converters/test_normalize.cpp b/tests/core/conversion/converters/test_normalize.cpp new file mode 100644 index 0000000000..a0c9a623c4 --- /dev/null +++ b/tests/core/conversion/converters/test_normalize.cpp @@ -0,0 +1,77 @@ +#include +#include "core/compiler.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +#define ATEN_INTERPOLATE_TESTS(name, graph_src, input_shape) \ + TEST(Converters, name##StaticConvertsCorrectly) { \ + const auto graph = graph_src; \ + \ + auto g = std::make_shared(); \ + torch::jit::parseIR(graph, &*g); \ + \ + auto in = at::randint(1, 10, input_shape, {at::kCUDA}); \ + auto jit_in = at::clone(in); \ + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ + auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); \ + \ + auto trt_in = at::clone(in); \ + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ + \ + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); \ + auto trt = trt_results[0].reshape(jit_results[0].sizes()); \ + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); \ + } \ + \ + TEST(Converters, name##DynamicConvertsCorrectly) { \ + const auto graph = graph_src; \ + \ + auto g = std::make_shared(); \ + torch::jit::parseIR(graph, &*g); \ + \ + auto in = at::randint(1, 10, input_shape, {at::kCUDA}); \ + auto jit_in = at::clone(in); \ + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ + auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); \ + \ + auto trt_in = at::clone(in); \ + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ + \ + auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); \ + auto trt = trt_results[0].reshape(jit_results[0].sizes()); \ + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); \ + } + +ATEN_INTERPOLATE_TESTS( + ATenNormOrder1RemoveDims, + R"IR( + graph(%x.1 : Tensor): + %2 : int[] = prim::Constant[value=[1, 2]]() + %3 : int = prim::Constant[value=1]() + %4 : bool = prim::Constant[value=0]() + %5 : Tensor = aten::norm(%x.1, %3, %2, %4) + return (%5))IR", + std::vector({3, 4, 3})); + +ATEN_INTERPOLATE_TESTS( + ATenNormOrder2RemoveDims, + R"IR( + graph(%x.1 : Tensor): + %2 : int[] = prim::Constant[value=[1, 2]]() + %3 : int = prim::Constant[value=2]() + %4 : bool = prim::Constant[value=0]() + %5 : Tensor = aten::norm(%x.1, %3, %2, %4) + return (%5))IR", + std::vector({3, 4, 3})); + +ATEN_INTERPOLATE_TESTS( + ATenNormOrder2KeepDims, + R"IR( + graph(%x.1 : Tensor): + %2 : int[] = prim::Constant[value=[1]]() + %3 : int = prim::Constant[value=2]() + %4 : bool = prim::Constant[value=1]() + %5 : Tensor = aten::norm(%x.1, %3, %2, %4) + return (%5))IR", + std::vector({3, 4, 3}));