Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mark negative indices support for gather as optional #681

Merged
merged 1 commit into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ if(BUILD_ONNXIFI)
set(ONNXIFI_SOURCES onnx_trt_backend.cpp)
endif()

# Build with negative indices support for Gather:
if (DEFINED SUPPORT_NEGATIVE_GATHER)
add_definitions("-DSUPPORT_NEGATIVE_GATHER=1")
endif()

# Build executables if BUILD_LIBRARY_ONLY flag is not set
if (NOT DEFINED BUILD_LIBRARY_ONLY)
set(EXECUTABLE_SOURCES
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Once you have cloned the repository, you can build the parser libraries and exec
// Ensure that you update your LD_LIBRARY_PATH to pick up the location of the newly built library:
export LD_LIBRARY_PATH=$PWD:$LD_LIBRARY_PATH

For building only the libraries, append `-DBUILD_LIBRARY_ONLY=1` to the CMake build command.
For building only the libraries, append `-DBUILD_LIBRARY_ONLY=1` to the CMake build command. If your model has Gather or GatherElements operations with negative indices, add `-DSUPPORT_NEGATIVE_GATHER` to the build command. Note that enabling negative-indices gather will have a performance impact on gathers with non-negative indices.

## Executable Usage

Expand Down
10 changes: 8 additions & 2 deletions builtin_op_importers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1201,8 +1201,11 @@ DEFINE_BUILTIN_OP_IMPORTER(Gather)
TRT_CHECK(convertAxis(axis, nbDims));
LOG_VERBOSE("Using Gather axis: " << axis);

// Convert any negative indices to positive ones
// Support for negative indices can be enabled through adding -DSUPPORT_NEGATIVE_GATHER=1 in the CMake build command.
// This will unnecessarily reduce performance of networks that use only non-negative Gather indices.
#if SUPPORT_NEGATIVE_GATHER
indices = convertGatherIndices(ctx, data, indices, axis);
#endif // SUPPORT_NEGATIVE_GATHER

auto* layer = ctx->network()->addGather(*data, *indices, axis);
ctx->registerLayer(layer, getNodeName(node));
Expand Down Expand Up @@ -1251,8 +1254,11 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
int32_t axis = attrs.get<int32_t>("axis", 0);
int32_t dataNbDims = daDims.nbDims;

// Convert any negative indices to positive ones
// Support for negative indices can be enabled through adding -DSUPPORT_NEGATIVE_GATHER=1 in the CMake build command.
// This will unnecessarily reduce performance of networks that use only non-negative Gather indices.
#if SUPPORT_NEGATIVE_GATHER
index = convertGatherIndices(ctx, data, index, axis);
#endif // SUPPORT_NEGATIVE_GATHER

TRT_CHECK(convertAxis(axis, dataNbDims));
LOG_VERBOSE("Using Gather axis: " << axis);
Expand Down
22 changes: 14 additions & 8 deletions onnx2trt_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,21 @@ nvinfer1::ITensor* createZeroTensor(IImporterContext* ctx, nvinfer1::ITensor* da

nvinfer1::ITensor* convertGatherIndices(IImporterContext* ctx, nvinfer1::ITensor* data, nvinfer1::ITensor* indices, int32_t axis)
{
// Create a condition tensor that is 1 for the elements in indices that are < 0 or 0 otherwise
auto condition = ctx->network()->addElementWise(*indices, *createZeroTensor(ctx, indices), nvinfer1::ElementWiseOperation::kLESS)->getOutput(0);
const int32_t n = indices->getDimensions().nbDims;
auto axisLength = getAxisLength(ctx, data, axis);
broadcastTensors(ctx, axisLength, indices);
// Create a shifted tensor that is indices + axisLength
auto shifted = ctx->network()->addElementWise(*indices, *axisLength, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
// Select between the shifted and original data based on condition
auto select = ctx->network()->addSelect(*condition, *shifted, *indices);
return select->getOutput(0);
broadcastTensor(ctx, axisLength, n);

// The formula here implements "indices < 0 ? indices + axisLength : indices"
// via the formula "indices - axisLength * max(-1, min(0, indices))".
// Think of the "max(-1, min(0, indices))" as extracting the sign bit from the indices.
const nvinfer1::Dims d = makeDims(n, 1);
auto zero = addConstantScalar(ctx, 0, ::ONNX_NAMESPACE::TensorProto::INT32, d)->getOutput(0);
auto minusOne = addConstantScalar(ctx, -1, ::ONNX_NAMESPACE::TensorProto::INT32, d)->getOutput(0);
auto min = ctx->network()->addElementWise(*zero, *indices, nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
auto max = ctx->network()->addElementWise(*minusOne, *min, nvinfer1::ElementWiseOperation::kMAX)->getOutput(0);
auto prod = ctx->network()->addElementWise(*max, *axisLength, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
auto sub = ctx->network()->addElementWise(*indices, *prod, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0);
return sub;
}

template <typename DataType>
Expand Down