Skip to content

Commit

Permalink
Mark negative indices support for gather as optional (#681)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Chen <[email protected]>
  • Loading branch information
kevinch-nv committed Jul 2, 2021
1 parent 567503d commit dd15315
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 11 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ if(BUILD_ONNXIFI)
set(ONNXIFI_SOURCES onnx_trt_backend.cpp)
endif()

# Build with negative indices support for Gather:
if (DEFINED SUPPORT_NEGATIVE_GATHER)
add_definitions("-DSUPPORT_NEGATIVE_GATHER=1")
endif()

# Build executables if BUILD_LIBRARY_ONLY flag is not set
if (NOT DEFINED BUILD_LIBRARY_ONLY)
set(EXECUTABLE_SOURCES
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Once you have cloned the repository, you can build the parser libraries and exec
// Ensure that you update your LD_LIBRARY_PATH to pick up the location of the newly built library:
export LD_LIBRARY_PATH=$PWD:$LD_LIBRARY_PATH

For building only the libraries, append `-DBUILD_LIBRARY_ONLY=1` to the CMake build command.
For building only the libraries, append `-DBUILD_LIBRARY_ONLY=1` to the CMake build command. If your model has Gather or GatherElements operations with negative indices, add `-DSUPPORT_NEGATIVE_GATHER` to the build command. Note that enabling negative-indices gather will have a performance impact on gathers with non-negative indices.

## Executable Usage

Expand Down
10 changes: 8 additions & 2 deletions builtin_op_importers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1201,8 +1201,11 @@ DEFINE_BUILTIN_OP_IMPORTER(Gather)
TRT_CHECK(convertAxis(axis, nbDims));
LOG_VERBOSE("Using Gather axis: " << axis);

// Convert any negative indices to positive ones
// Support for negative indices can be enabled through adding -DSUPPORT_NEGATIVE_GATHER=1 in the CMake build command.
// This will unnecessarily reduce performance of networks that use only non-negative Gather indices.
#if SUPPORT_NEGATIVE_GATHER
indices = convertGatherIndices(ctx, data, indices, axis);
#endif // SUPPORT_NEGATIVE_GATHER

auto* layer = ctx->network()->addGather(*data, *indices, axis);
ctx->registerLayer(layer, getNodeName(node));
Expand Down Expand Up @@ -1251,8 +1254,11 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
int32_t axis = attrs.get<int32_t>("axis", 0);
int32_t dataNbDims = daDims.nbDims;

// Convert any negative indices to positive ones
// Support for negative indices can be enabled through adding -DSUPPORT_NEGATIVE_GATHER=1 in the CMake build command.
// This will unnecessarily reduce performance of networks that use only non-negative Gather indices.
#if SUPPORT_NEGATIVE_GATHER
index = convertGatherIndices(ctx, data, index, axis);
#endif // SUPPORT_NEGATIVE_GATHER

TRT_CHECK(convertAxis(axis, dataNbDims));
LOG_VERBOSE("Using Gather axis: " << axis);
Expand Down
22 changes: 14 additions & 8 deletions onnx2trt_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,21 @@ nvinfer1::ITensor* createZeroTensor(IImporterContext* ctx, nvinfer1::ITensor* da

nvinfer1::ITensor* convertGatherIndices(IImporterContext* ctx, nvinfer1::ITensor* data, nvinfer1::ITensor* indices, int32_t axis)
{
// Create a condition tensor that is 1 for the elements in indices that are < 0 or 0 otherwise
auto condition = ctx->network()->addElementWise(*indices, *createZeroTensor(ctx, indices), nvinfer1::ElementWiseOperation::kLESS)->getOutput(0);
const int32_t n = indices->getDimensions().nbDims;
auto axisLength = getAxisLength(ctx, data, axis);
broadcastTensors(ctx, axisLength, indices);
// Create a shifted tensor that is indices + axisLength
auto shifted = ctx->network()->addElementWise(*indices, *axisLength, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
// Select between the shifted and original data based on condition
auto select = ctx->network()->addSelect(*condition, *shifted, *indices);
return select->getOutput(0);
broadcastTensor(ctx, axisLength, n);

// The formula here implements "indices < 0 ? indices + axisLength : indices"
// via the formula "indices - axisLength * max(-1, min(0, indices))".
// Think of the "max(-1, min(0, indices))" as extracting the sign bit from the indices.
const nvinfer1::Dims d = makeDims(n, 1);
auto zero = addConstantScalar(ctx, 0, ::ONNX_NAMESPACE::TensorProto::INT32, d)->getOutput(0);
auto minusOne = addConstantScalar(ctx, -1, ::ONNX_NAMESPACE::TensorProto::INT32, d)->getOutput(0);
auto min = ctx->network()->addElementWise(*zero, *indices, nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
auto max = ctx->network()->addElementWise(*minusOne, *min, nvinfer1::ElementWiseOperation::kMAX)->getOutput(0);
auto prod = ctx->network()->addElementWise(*max, *axisLength, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
auto sub = ctx->network()->addElementWise(*indices, *prod, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0);
return sub;
}

template <typename DataType>
Expand Down

0 comments on commit dd15315

Please sign in to comment.