diff --git a/plugin/batchedNMSPlugin/batchedNMSPlugin.cpp b/plugin/batchedNMSPlugin/batchedNMSPlugin.cpp index ee0c8b93..ec015ee0 100644 --- a/plugin/batchedNMSPlugin/batchedNMSPlugin.cpp +++ b/plugin/batchedNMSPlugin/batchedNMSPlugin.cpp @@ -63,43 +63,84 @@ int BatchedNMSPlugin::initialize() void BatchedNMSPlugin::terminate() {} -Dims BatchedNMSPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) +DimsExprs BatchedNMSPlugin::getOutputDimensions( + int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) { - ASSERT(nbInputDims == 2); - ASSERT(index >= 0 && index < this->getNbOutputs()); - ASSERT(inputs[0].nbDims == 3); - ASSERT(inputs[1].nbDims == 2 || (inputs[1].nbDims == 3 && inputs[1].d[2] == 1)); - // boxesSize: number of box coordinates for one sample - boxesSize = inputs[0].d[0] * inputs[0].d[1] * inputs[0].d[2]; - // scoresSize: number of scores for one sample - scoresSize = inputs[1].d[0] * inputs[1].d[1]; + ASSERT(nbInputs == 2); + ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs()); + + // Shape of boxes input should be + // Constant shape: [batch_size, num_boxes, num_classes, 4] or [batch_size, num_boxes, 1, 4] + // shareLocation == 0 or 1 + // or + // Dynamic shape: some dimension values may be -1 + ASSERT(inputs[0].nbDims == 4); + + // Shape of scores input should be + // Constant shape: [batch_size, num_boxes, num_classes] or [batch_size, num_boxes, num_classes, 1] + // or + // Dynamic shape: some dimension values may be -1 + ASSERT(inputs[1].nbDims == 3 || inputs[1].nbDims == 4); + + if (inputs[0].d[0]->isConstant() && inputs[0].d[1]->isConstant() && inputs[0].d[2]->isConstant() + && inputs[0].d[3]->isConstant()) + { + boxesSize = exprBuilder + .operation(DimensionOperation::kPROD, + *exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[1], *inputs[0].d[2]), + *inputs[0].d[3]) + ->getConstantValue(); + } + + if (inputs[1].d[0]->isConstant() && inputs[1].d[1]->isConstant() && inputs[1].d[2]->isConstant()) + { + scoresSize + = exprBuilder.operation(DimensionOperation::kPROD, *inputs[1].d[1], *inputs[1].d[2])->getConstantValue(); + } + + DimsExprs out_dim; // num_detections - if (index == 0) + if (outputIndex == 0) { - Dims dim0{}; - dim0.nbDims = 0; - return dim0; + out_dim.nbDims = 2; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = exprBuilder.constant(1); } // nmsed_boxes - if (index == 1) + else if (outputIndex == 1) + { + out_dim.nbDims = 3; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = exprBuilder.constant(param.keepTopK); + out_dim.d[2] = exprBuilder.constant(4); + } + // nmsed_scores + else if (outputIndex == 2) { - return DimsHW(param.keepTopK, 4); + out_dim.nbDims = 2; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = exprBuilder.constant(param.keepTopK); } - // nmsed_scores or nmsed_classes - Dims dim1{}; - dim1.nbDims = 1; - dim1.d[0] = param.keepTopK; - return dim1; + // nmsed_classes + else + { + out_dim.nbDims = 2; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = exprBuilder.constant(param.keepTopK); + } + + return out_dim; } -size_t BatchedNMSPlugin::getWorkspaceSize(int maxBatchSize) const +size_t BatchedNMSPlugin::getWorkspaceSize( + const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const { - return detectionInferenceWorkspaceSize(param.shareLocation, maxBatchSize, boxesSize, scoresSize, param.numClasses, - numPriors, param.topK, DataType::kFLOAT, DataType::kFLOAT); + return detectionInferenceWorkspaceSize(param.shareLocation, inputs[0].dims.d[0], boxesSize, scoresSize, + param.numClasses, numPriors, param.topK, DataType::kFLOAT, DataType::kFLOAT); } -int BatchedNMSPlugin::enqueue( - int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) +int BatchedNMSPlugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) { const void* const locData = inputs[0]; const void* const confData = inputs[1]; @@ -109,7 +150,7 @@ int BatchedNMSPlugin::enqueue( void* nmsedScores = outputs[2]; void* nmsedClasses = outputs[3]; - pluginStatus_t status = nmsInference(stream, batchSize, boxesSize, scoresSize, param.shareLocation, + pluginStatus_t status = nmsInference(stream, inputDesc[0].dims.d[0], boxesSize, scoresSize, param.shareLocation, param.backgroundLabelId, numPriors, param.numClasses, param.topK, param.keepTopK, param.scoreThreshold, param.iouThreshold, DataType::kFLOAT, locData, DataType::kFLOAT, confData, keepCount, nmsedBoxes, nmsedScores, nmsedClasses, workspace, param.isNormalized, false, mClipBoxes); @@ -134,31 +175,46 @@ void BatchedNMSPlugin::serialize(void* buffer) const ASSERT(d == a + getSerializationSize()); } -void BatchedNMSPlugin::configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, - const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast, - const bool* outputIsBroadcast, nvinfer1::PluginFormat format, int maxBatchSize) +void BatchedNMSPlugin::configurePlugin( + const DynamicPluginTensorDesc* in, int nbInputs, const DynamicPluginTensorDesc* out, int nbOutputs) { ASSERT(nbInputs == 2); ASSERT(nbOutputs == 4); - ASSERT(inputDims[0].nbDims == 3); - ASSERT(inputDims[1].nbDims == 2 || (inputDims[1].nbDims == 3 && inputDims[1].d[2] == 1)); - ASSERT(std::none_of(inputIsBroadcast, inputIsBroadcast + nbInputs, [](bool b) { return b; })); - ASSERT(std::none_of(outputIsBroadcast, outputIsBroadcast + nbInputs, [](bool b) { return b; })); - boxesSize = inputDims[0].d[0] * inputDims[0].d[1] * inputDims[0].d[2]; - scoresSize = inputDims[1].d[0] * inputDims[1].d[1]; - // num_boxes - numPriors = inputDims[0].d[0]; + // Shape of boxes input should be + // Constant shape: [batch_size, num_boxes, num_classes, 4] or [batch_size, num_boxes, 1, 4] + // shareLocation == 0 or 1 const int numLocClasses = param.shareLocation ? 1 : param.numClasses; - // Third dimension of boxes must be either 1 or num_classes - ASSERT(inputDims[0].d[1] == numLocClasses); - ASSERT(inputDims[0].d[2] == 4); + ASSERT(in[0].desc.dims.nbDims == 4); + ASSERT(in[0].desc.dims.d[2] == numLocClasses); + ASSERT(in[0].desc.dims.d[3] == 4); + + // Shape of scores input should be + // Constant shape: [batch_size, num_boxes, num_classes] or [batch_size, num_boxes, num_classes, 1] + ASSERT(in[1].desc.dims.nbDims == 3 || (in[1].desc.dims.nbDims == 4 && in[1].desc.dims.d[3] == 1)); + + boxesSize = in[0].desc.dims.d[1] * in[0].desc.dims.d[2] * in[0].desc.dims.d[3]; + scoresSize = in[1].desc.dims.d[1] * in[1].desc.dims.d[2]; + // num_boxes + numPriors = in[0].desc.dims.d[1]; } -bool BatchedNMSPlugin::supportsFormat(DataType type, PluginFormat format) const +bool BatchedNMSPlugin::supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) { - return ((type == DataType::kFLOAT || type == DataType::kINT32) && format == PluginFormat::kNCHW); + ASSERT(0 <= pos && pos < 6); + const auto* in = inOut; + const auto* out = inOut + nbInputs; + switch (pos) + { + case 0: return in[0].type == DataType::kFLOAT && in[0].format == PluginFormat::kLINEAR; + case 1: return in[1].type == DataType::kFLOAT && in[1].format == PluginFormat::kLINEAR; + case 2: return out[0].type == DataType::kINT32 && out[0].format == PluginFormat::kLINEAR; + case 3: return out[1].type == DataType::kFLOAT && out[1].format == PluginFormat::kLINEAR; + case 4: return out[2].type == DataType::kFLOAT && out[2].format == PluginFormat::kLINEAR; + case 5: return out[3].type == DataType::kFLOAT && out[3].format == PluginFormat::kLINEAR; + } } + const char* BatchedNMSPlugin::getPluginType() const { return NMS_PLUGIN_NAME; @@ -174,7 +230,7 @@ void BatchedNMSPlugin::destroy() delete this; } -IPluginV2Ext* BatchedNMSPlugin::clone() const +IPluginV2DynamicExt* BatchedNMSPlugin::clone() const { auto* plugin = new BatchedNMSPlugin(param); plugin->boxesSize = boxesSize; @@ -210,16 +266,6 @@ void BatchedNMSPlugin::setClipParam(bool clip) mClipBoxes = clip; } -bool BatchedNMSPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const -{ - return false; -} - -bool BatchedNMSPlugin::canBroadcastInputAcrossBatch(int inputIndex) const -{ - return false; -} - BatchedNMSPluginCreator::BatchedNMSPluginCreator() : params{} { @@ -252,7 +298,7 @@ const PluginFieldCollection* BatchedNMSPluginCreator::getFieldNames() return &mFC; } -IPluginV2Ext* BatchedNMSPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) +IPluginV2DynamicExt* BatchedNMSPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) { const PluginField* fields = fc->fields; mClipBoxes = true; @@ -310,7 +356,8 @@ IPluginV2Ext* BatchedNMSPluginCreator::createPlugin(const char* name, const Plug return plugin; } -IPluginV2Ext* BatchedNMSPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) +IPluginV2DynamicExt* BatchedNMSPluginCreator::deserializePlugin( + const char* name, const void* serialData, size_t serialLength) { // This object will be deleted when the network is destroyed, which will // call NMS::destroy() diff --git a/plugin/batchedNMSPlugin/batchedNMSPlugin.h b/plugin/batchedNMSPlugin/batchedNMSPlugin.h index ca0aaf13..7dfd1707 100644 --- a/plugin/batchedNMSPlugin/batchedNMSPlugin.h +++ b/plugin/batchedNMSPlugin/batchedNMSPlugin.h @@ -28,7 +28,7 @@ namespace nvinfer1 namespace plugin { -class BatchedNMSPlugin : public IPluginV2Ext +class BatchedNMSPlugin : public IPluginV2DynamicExt { public: BatchedNMSPlugin(NMSParameters param); @@ -39,26 +39,27 @@ class BatchedNMSPlugin : public IPluginV2Ext int getNbOutputs() const override; - Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override; + DimsExprs getOutputDimensions( + int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) override; int initialize() override; void terminate() override; - size_t getWorkspaceSize(int maxBatchSize) const override; + size_t getWorkspaceSize( + const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const override; - int enqueue( - int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override; + int enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs, + void* const* outputs, void* workspace, cudaStream_t stream) override; size_t getSerializationSize() const override; void serialize(void* buffer) const override; - void configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, - const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast, - const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize) override; + void configurePlugin( + const DynamicPluginTensorDesc* in, int nbInputs, const DynamicPluginTensorDesc* out, int nbOutputs) override; - bool supportsFormat(DataType type, PluginFormat format) const override; + bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override; const char* getPluginType() const override; @@ -66,7 +67,7 @@ class BatchedNMSPlugin : public IPluginV2Ext void destroy() override; - IPluginV2Ext* clone() const override; + IPluginV2DynamicExt* clone() const override; nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputType, int nbInputs) const override; @@ -74,10 +75,6 @@ class BatchedNMSPlugin : public IPluginV2Ext const char* getPluginNamespace() const override; - bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override; - - bool canBroadcastInputAcrossBatch(int inputIndex) const override; - void setClipParam(bool clip); private: @@ -102,9 +99,9 @@ class BatchedNMSPluginCreator : public BaseCreator const PluginFieldCollection* getFieldNames() override; - IPluginV2Ext* createPlugin(const char* name, const PluginFieldCollection* fc) override; + IPluginV2DynamicExt* createPlugin(const char* name, const PluginFieldCollection* fc) override; - IPluginV2Ext* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override; + IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override; private: static PluginFieldCollection mFC;