Skip to content

Commit

Permalink
feat(//core/conversion/converters/impl/plugins): Created interpolate …
Browse files Browse the repository at this point in the history
…plugin, works for mode='linear'

Signed-off-by: Abhiram Iyer <[email protected]>

Signed-off-by: Abhiram Iyer <[email protected]>
  • Loading branch information
abhi-iyer committed Jun 18, 2020
1 parent a0848b1 commit 205ab99
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 31 deletions.
44 changes: 18 additions & 26 deletions core/conversion/converters/impl/plugins/interpolate_plugin.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
// #include <string>
// #include <iostream>
// #include <sstream>
// #include <ATen/ATen.h>
// #include <ATen/cuda/CUDAEvent.h>
// #include <cuda_runtime_api.h>
// #include <vector>
// #include <cudnn.h>

// #include "core/util/prelude.h"
// #include "torch/torch.h"
// #include "NvInfer.h"

#include "interpolate_plugin.h"

using namespace nvinfer1;
Expand Down Expand Up @@ -80,27 +67,30 @@ int InterpolatePlugin::getNbOutputs() const {
}

const char* InterpolatePlugin::getPluginType() const {
return "Interpolate_TRTorch";
return "Interpolate";
}

const char* InterpolatePlugin::getPluginVersion() const{
return "1";
}

const char* InterpolatePlugin::getPluginNamespace() const {
return "trtorch";
return "";
}

int InterpolatePlugin::getTensorRTVersion() const {
return NV_TENSORRT_MAJOR;
}

nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
}

nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) {
return inputs[0];
nvinfer1::DimsExprs output(inputs[0]);

for (unsigned int i = 0; i < out_shape.size(); i++) {
output.d[i] = exprBuilder.constant(out_shape[i]);
}

return output;
}

nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
Expand All @@ -109,6 +99,8 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer

int InterpolatePlugin::initialize() {
tensor_options = tensor_options.device(c10::kCUDA);

// c10::kFloat = FLOAT32
tensor_options = tensor_options.dtype(c10::kFloat);

return 0;
Expand Down Expand Up @@ -164,6 +156,7 @@ size_t InterpolatePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inp
int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs,
void *const *outputs, void *workspace,
cudaStream_t stream) {

at::Tensor input = at::from_blob((void*) inputs[0], in_shape, [](void*){}, tensor_options);
at::Tensor output = at::from_blob(outputs[0], out_shape, [](void*){}, tensor_options);

Expand Down Expand Up @@ -200,13 +193,11 @@ int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, cons
* InterpolatePluginCreator class implementations
*/
const char* InterpolatePluginCreator::getPluginNamespace() const {
return "trtorch";
return "";
}

void InterpolatePluginCreator::setPluginNamespace(const char* libNamespace) {}

const char* InterpolatePluginCreator::getPluginName() const {
return "interpolate";
return "Interpolate";
}

const char* InterpolatePluginCreator::getPluginVersion() const {
Expand All @@ -217,7 +208,9 @@ nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, co
return nullptr;
}

nvinfer1::IPluginV2DynamicExt* InterpolatePluginCreator::createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners) {
InterpolatePlugin* InterpolatePluginCreator::createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape,
std::vector<int64_t> size,
std::string mode, bool align_corners) {
name = name;
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
}
Expand All @@ -238,5 +231,4 @@ REGISTER_TENSORRT_PLUGIN(InterpolatePluginCreator);
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch

} // namespace trtorch
8 changes: 3 additions & 5 deletions core/conversion/converters/impl/plugins/interpolate_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {

const char* getPluginNamespace() const override;

void setPluginNamespace(const char* pluginNamespace) {}

int getTensorRTVersion() const override;
void setPluginNamespace(const char* pluginNamespace) override {};

nvinfer1::IPluginV2DynamicExt* clone() const override;

Expand Down Expand Up @@ -107,15 +105,15 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {

const char* getPluginNamespace() const override;

void setPluginNamespace(const char* libNamespace) override;
void setPluginNamespace(const char* libNamespace) override {};

const char* getPluginName() const override;

const char* getPluginVersion() const override;

nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override;

nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners);
InterpolatePlugin* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners);

nvinfer1::IPluginV2* deserializePlugin(const char* name, const void *serialData, size_t serialLength) override;

Expand Down

0 comments on commit 205ab99

Please sign in to comment.