Skip to content

Commit

Permalink
Update batchedNMS plugin to IPluginV2DynamicExt
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <[email protected]>
  • Loading branch information
rajeevsrao committed Jul 27, 2020
1 parent 11bdf61 commit 0e1638d
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 71 deletions.
157 changes: 102 additions & 55 deletions plugin/batchedNMSPlugin/batchedNMSPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,43 +63,84 @@ int BatchedNMSPlugin::initialize()

void BatchedNMSPlugin::terminate() {}

Dims BatchedNMSPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
DimsExprs BatchedNMSPlugin::getOutputDimensions(
int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder)
{
ASSERT(nbInputDims == 2);
ASSERT(index >= 0 && index < this->getNbOutputs());
ASSERT(inputs[0].nbDims == 3);
ASSERT(inputs[1].nbDims == 2 || (inputs[1].nbDims == 3 && inputs[1].d[2] == 1));
// boxesSize: number of box coordinates for one sample
boxesSize = inputs[0].d[0] * inputs[0].d[1] * inputs[0].d[2];
// scoresSize: number of scores for one sample
scoresSize = inputs[1].d[0] * inputs[1].d[1];
ASSERT(nbInputs == 2);
ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs());

// Shape of boxes input should be
// Constant shape: [batch_size, num_boxes, num_classes, 4] or [batch_size, num_boxes, 1, 4]
// shareLocation == 0 or 1
// or
// Dynamic shape: some dimension values may be -1
ASSERT(inputs[0].nbDims == 4);

// Shape of scores input should be
// Constant shape: [batch_size, num_boxes, num_classes] or [batch_size, num_boxes, num_classes, 1]
// or
// Dynamic shape: some dimension values may be -1
ASSERT(inputs[1].nbDims == 3 || inputs[1].nbDims == 4);

if (inputs[0].d[0]->isConstant() && inputs[0].d[1]->isConstant() && inputs[0].d[2]->isConstant()
&& inputs[0].d[3]->isConstant())
{
boxesSize = exprBuilder
.operation(DimensionOperation::kPROD,
*exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[1], *inputs[0].d[2]),
*inputs[0].d[3])
->getConstantValue();
}

if (inputs[1].d[0]->isConstant() && inputs[1].d[1]->isConstant() && inputs[1].d[2]->isConstant())
{
scoresSize
= exprBuilder.operation(DimensionOperation::kPROD, *inputs[1].d[1], *inputs[1].d[2])->getConstantValue();
}

DimsExprs out_dim;
// num_detections
if (index == 0)
if (outputIndex == 0)
{
Dims dim0{};
dim0.nbDims = 0;
return dim0;
out_dim.nbDims = 2;
out_dim.d[0] = inputs[0].d[0];
out_dim.d[1] = exprBuilder.constant(1);
}
// nmsed_boxes
if (index == 1)
else if (outputIndex == 1)
{
out_dim.nbDims = 3;
out_dim.d[0] = inputs[0].d[0];
out_dim.d[1] = exprBuilder.constant(param.keepTopK);
out_dim.d[2] = exprBuilder.constant(4);
}
// nmsed_scores
else if (outputIndex == 2)
{
return DimsHW(param.keepTopK, 4);
out_dim.nbDims = 2;
out_dim.d[0] = inputs[0].d[0];
out_dim.d[1] = exprBuilder.constant(param.keepTopK);
}
// nmsed_scores or nmsed_classes
Dims dim1{};
dim1.nbDims = 1;
dim1.d[0] = param.keepTopK;
return dim1;
// nmsed_classes
else
{
out_dim.nbDims = 2;
out_dim.d[0] = inputs[0].d[0];
out_dim.d[1] = exprBuilder.constant(param.keepTopK);
}

return out_dim;
}

size_t BatchedNMSPlugin::getWorkspaceSize(int maxBatchSize) const
size_t BatchedNMSPlugin::getWorkspaceSize(
const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const
{
return detectionInferenceWorkspaceSize(param.shareLocation, maxBatchSize, boxesSize, scoresSize, param.numClasses,
numPriors, param.topK, DataType::kFLOAT, DataType::kFLOAT);
return detectionInferenceWorkspaceSize(param.shareLocation, inputs[0].dims.d[0], boxesSize, scoresSize,
param.numClasses, numPriors, param.topK, DataType::kFLOAT, DataType::kFLOAT);
}

int BatchedNMSPlugin::enqueue(
int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream)
int BatchedNMSPlugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream)
{
const void* const locData = inputs[0];
const void* const confData = inputs[1];
Expand All @@ -109,7 +150,7 @@ int BatchedNMSPlugin::enqueue(
void* nmsedScores = outputs[2];
void* nmsedClasses = outputs[3];

pluginStatus_t status = nmsInference(stream, batchSize, boxesSize, scoresSize, param.shareLocation,
pluginStatus_t status = nmsInference(stream, inputDesc[0].dims.d[0], boxesSize, scoresSize, param.shareLocation,
param.backgroundLabelId, numPriors, param.numClasses, param.topK, param.keepTopK, param.scoreThreshold,
param.iouThreshold, DataType::kFLOAT, locData, DataType::kFLOAT, confData, keepCount, nmsedBoxes, nmsedScores,
nmsedClasses, workspace, param.isNormalized, false, mClipBoxes);
Expand All @@ -134,31 +175,46 @@ void BatchedNMSPlugin::serialize(void* buffer) const
ASSERT(d == a + getSerializationSize());
}

void BatchedNMSPlugin::configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs,
const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast,
const bool* outputIsBroadcast, nvinfer1::PluginFormat format, int maxBatchSize)
void BatchedNMSPlugin::configurePlugin(
const DynamicPluginTensorDesc* in, int nbInputs, const DynamicPluginTensorDesc* out, int nbOutputs)
{
ASSERT(nbInputs == 2);
ASSERT(nbOutputs == 4);
ASSERT(inputDims[0].nbDims == 3);
ASSERT(inputDims[1].nbDims == 2 || (inputDims[1].nbDims == 3 && inputDims[1].d[2] == 1));
ASSERT(std::none_of(inputIsBroadcast, inputIsBroadcast + nbInputs, [](bool b) { return b; }));
ASSERT(std::none_of(outputIsBroadcast, outputIsBroadcast + nbInputs, [](bool b) { return b; }));

boxesSize = inputDims[0].d[0] * inputDims[0].d[1] * inputDims[0].d[2];
scoresSize = inputDims[1].d[0] * inputDims[1].d[1];
// num_boxes
numPriors = inputDims[0].d[0];
// Shape of boxes input should be
// Constant shape: [batch_size, num_boxes, num_classes, 4] or [batch_size, num_boxes, 1, 4]
// shareLocation == 0 or 1
const int numLocClasses = param.shareLocation ? 1 : param.numClasses;
// Third dimension of boxes must be either 1 or num_classes
ASSERT(inputDims[0].d[1] == numLocClasses);
ASSERT(inputDims[0].d[2] == 4);
ASSERT(in[0].desc.dims.nbDims == 4);
ASSERT(in[0].desc.dims.d[2] == numLocClasses);
ASSERT(in[0].desc.dims.d[3] == 4);

// Shape of scores input should be
// Constant shape: [batch_size, num_boxes, num_classes] or [batch_size, num_boxes, num_classes, 1]
ASSERT(in[1].desc.dims.nbDims == 3 || (in[1].desc.dims.nbDims == 4 && in[1].desc.dims.d[3] == 1));

boxesSize = in[0].desc.dims.d[1] * in[0].desc.dims.d[2] * in[0].desc.dims.d[3];
scoresSize = in[1].desc.dims.d[1] * in[1].desc.dims.d[2];
// num_boxes
numPriors = in[0].desc.dims.d[1];
}

bool BatchedNMSPlugin::supportsFormat(DataType type, PluginFormat format) const
bool BatchedNMSPlugin::supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs)
{
return ((type == DataType::kFLOAT || type == DataType::kINT32) && format == PluginFormat::kNCHW);
ASSERT(0 <= pos && pos < 6);
const auto* in = inOut;
const auto* out = inOut + nbInputs;
switch (pos)
{
case 0: return in[0].type == DataType::kFLOAT && in[0].format == PluginFormat::kLINEAR;
case 1: return in[1].type == DataType::kFLOAT && in[1].format == PluginFormat::kLINEAR;
case 2: return out[0].type == DataType::kINT32 && out[0].format == PluginFormat::kLINEAR;
case 3: return out[1].type == DataType::kFLOAT && out[1].format == PluginFormat::kLINEAR;
case 4: return out[2].type == DataType::kFLOAT && out[2].format == PluginFormat::kLINEAR;
case 5: return out[3].type == DataType::kFLOAT && out[3].format == PluginFormat::kLINEAR;
}
}

const char* BatchedNMSPlugin::getPluginType() const
{
return NMS_PLUGIN_NAME;
Expand All @@ -174,7 +230,7 @@ void BatchedNMSPlugin::destroy()
delete this;
}

IPluginV2Ext* BatchedNMSPlugin::clone() const
IPluginV2DynamicExt* BatchedNMSPlugin::clone() const
{
auto* plugin = new BatchedNMSPlugin(param);
plugin->boxesSize = boxesSize;
Expand Down Expand Up @@ -210,16 +266,6 @@ void BatchedNMSPlugin::setClipParam(bool clip)
mClipBoxes = clip;
}

bool BatchedNMSPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
{
return false;
}

bool BatchedNMSPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
{
return false;
}

BatchedNMSPluginCreator::BatchedNMSPluginCreator()
: params{}
{
Expand Down Expand Up @@ -252,7 +298,7 @@ const PluginFieldCollection* BatchedNMSPluginCreator::getFieldNames()
return &mFC;
}

IPluginV2Ext* BatchedNMSPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
IPluginV2DynamicExt* BatchedNMSPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
{
const PluginField* fields = fc->fields;
mClipBoxes = true;
Expand Down Expand Up @@ -310,7 +356,8 @@ IPluginV2Ext* BatchedNMSPluginCreator::createPlugin(const char* name, const Plug
return plugin;
}

IPluginV2Ext* BatchedNMSPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
IPluginV2DynamicExt* BatchedNMSPluginCreator::deserializePlugin(
const char* name, const void* serialData, size_t serialLength)
{
// This object will be deleted when the network is destroyed, which will
// call NMS::destroy()
Expand Down
29 changes: 13 additions & 16 deletions plugin/batchedNMSPlugin/batchedNMSPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace nvinfer1
namespace plugin
{

class BatchedNMSPlugin : public IPluginV2Ext
class BatchedNMSPlugin : public IPluginV2DynamicExt
{
public:
BatchedNMSPlugin(NMSParameters param);
Expand All @@ -39,45 +39,42 @@ class BatchedNMSPlugin : public IPluginV2Ext

int getNbOutputs() const override;

Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
DimsExprs getOutputDimensions(
int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) override;

int initialize() override;

void terminate() override;

size_t getWorkspaceSize(int maxBatchSize) const override;
size_t getWorkspaceSize(
const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const override;

int enqueue(
int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override;
int enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) override;

size_t getSerializationSize() const override;

void serialize(void* buffer) const override;

void configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs,
const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast,
const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize) override;
void configurePlugin(
const DynamicPluginTensorDesc* in, int nbInputs, const DynamicPluginTensorDesc* out, int nbOutputs) override;

bool supportsFormat(DataType type, PluginFormat format) const override;
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override;

const char* getPluginType() const override;

const char* getPluginVersion() const override;

void destroy() override;

IPluginV2Ext* clone() const override;
IPluginV2DynamicExt* clone() const override;

nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputType, int nbInputs) const override;

void setPluginNamespace(const char* libNamespace) override;

const char* getPluginNamespace() const override;

bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;

bool canBroadcastInputAcrossBatch(int inputIndex) const override;

void setClipParam(bool clip);

private:
Expand All @@ -102,9 +99,9 @@ class BatchedNMSPluginCreator : public BaseCreator

const PluginFieldCollection* getFieldNames() override;

IPluginV2Ext* createPlugin(const char* name, const PluginFieldCollection* fc) override;
IPluginV2DynamicExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;

IPluginV2Ext* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;

private:
static PluginFieldCollection mFC;
Expand Down

2 comments on commit 0e1638d

@roborocklsm
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IPluginV2DynamicExt implementation requires network without implicit batch dimension while previous version could handle implicit batch dimension. There is a compatibility issue here.

When I build the TensorRT with this commit, it calls ERROR on my program.

@rajeevsrao
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IPluginV2DynamicExt implementation requires network without implicit batch dimension while previous version could handle implicit batch dimension. There is a compatibility issue here.

When I build the TensorRT with this commit, it calls ERROR on my program.

@roborocklsm Fixes in #738. Can you check and let me know if the new plugin (BatchedNMSDynamic_TRT) works?

Please sign in to comment.