chore: Upgrade TensorRT version to TRT 10 EA #2699

Merged 80 commits on Apr 23, 2024

Commits
cd86660
feat: Add save API for torch-trt compiled models
peri044 Mar 14, 2024
3ece71b
chore: resolve merge conflicts
peri044 Mar 15, 2024
eab0dba
chore: Fix save failures
peri044 Mar 18, 2024
b191d62
chore: update to 2.3 rc build
peri044 Mar 18, 2024
ce606fe
chore: rebase with release/2.3 branch
peri044 Mar 19, 2024
8674a3c
chore: minor fixes
peri044 Mar 19, 2024
f4e8fe9
chore: remove duplicate bert test case
peri044 Mar 20, 2024
4ae6ab9
chore: remove comments
peri044 Mar 20, 2024
fff1b80
chore: Upgrade to TRT 10.0
peri044 Mar 12, 2024
39ca77d
chore: more fixes
peri044 Mar 21, 2024
5431ee3
chore: update trt version
peri044 Mar 25, 2024
0c03de5
chore: more updates
peri044 Mar 26, 2024
982dbd2
parent f39e89e3964bc3d6ea3a6989b1e4099e1bb3e6dd
peri044 Mar 25, 2024
1ae46e9
chore: more updates
peri044 Mar 27, 2024
ae87fba
chore: rebase with save
peri044 Mar 27, 2024
beb5920
chore: Update versions
peri044 Mar 27, 2024
f0068c6
chore: update tensorrt version in CI
peri044 Mar 27, 2024
39261b9
chore: more updates
peri044 Mar 27, 2024
3753150
chore: more fixes
peri044 Apr 2, 2024
16a191c
Merge branch 'release/2.3' into trt_10
peri044 Apr 2, 2024
c355766
chore: remove NvUtils.h
peri044 Apr 2, 2024
2d237dc
chore: more updates
peri044 Apr 2, 2024
e4b4429
chore: change lib64 to lib in rhel BUILD file
peri044 Apr 2, 2024
fa4fb9c
chore: more updates
peri044 Apr 2, 2024
e11eb60
chore: fix TRT version
peri044 Apr 2, 2024
092feb2
chore: more updates
peri044 Apr 2, 2024
09ecf26
fix shape bug in bitwise ops
zewenli98 Apr 3, 2024
85e04c5
chore: update to rhel9
peri044 Apr 3, 2024
6a3664e
Merge branch 'trt_10' of github.com:pytorch/TensorRT into trt_10
peri044 Apr 3, 2024
41229d6
chore: change trt version
peri044 Apr 3, 2024
9d7a656
fix test bug and add more tests
zewenli98 Apr 3, 2024
5e911a9
chore: delete mirror of rules_pkg
peri044 Apr 3, 2024
dae0eb2
chore: fix conv test
peri044 Apr 3, 2024
2a32b13
Merge branch 'trt_10' of github.com:pytorch/TensorRT into trt_10
peri044 Apr 3, 2024
4676cd2
chore: fix trt version range
peri044 Apr 4, 2024
88efe8e
chore: fix trt rangfe
peri044 Apr 4, 2024
f9b40e6
chore: minor fix
peri044 Apr 4, 2024
b86aec2
chore: update rules_pkg
peri044 Apr 4, 2024
6630281
chore: minor fixes
peri044 Apr 4, 2024
fca55fe
chore: expt
peri044 Apr 4, 2024
1ca01e7
chore: update WORKSPACE tmpl
peri044 Apr 5, 2024
cdf5d07
chore: rebase with 2.3
peri044 Apr 5, 2024
6ffb85e
chore: fix
peri044 Apr 5, 2024
76af510
chore: remove cudnn dep
peri044 Apr 6, 2024
f9cf75a
chore: fix
peri044 Apr 6, 2024
33ba8b2
chore: updates
peri044 Apr 8, 2024
923377c
chore: update post-build script
peri044 Apr 9, 2024
89f04db
chore: remove trt dep
peri044 Apr 9, 2024
7620acc
chore: updates
peri044 Apr 9, 2024
62332fb
chore: set ld_library path in post script
peri044 Apr 9, 2024
96a8bf6
chore: updates
peri044 Apr 9, 2024
041f6a3
chore: updates
peri044 Apr 9, 2024
83e9a0b
chore: disable smoke test
peri044 Apr 9, 2024
e8529b0
chore: updates
peri044 Apr 9, 2024
1357112
chore: updates
peri044 Apr 9, 2024
608a6d2
chore: updates
peri044 Apr 10, 2024
1b34b32
chore: updates
peri044 Apr 10, 2024
89cb55a
chore: updates
peri044 Apr 10, 2024
4323e36
chore: updates
peri044 Apr 10, 2024
60b3e51
chore: update hw_compat
peri044 Apr 10, 2024
05627cd
chore: updates
peri044 Apr 12, 2024
d16585f
chore: update streams
peri044 Apr 12, 2024
16088e6
chore: updates
peri044 Apr 12, 2024
3d149ef
chore: updates
peri044 Apr 13, 2024
3addcae
chore: updates
peri044 Apr 13, 2024
b0e92d8
chore: update hw_compat.ts
peri044 Apr 15, 2024
d285d27
fix dynamic shape bugs for test_binary_ops_aten
zewenli98 Apr 15, 2024
d78a846
chore: revert layer_norm test
peri044 Apr 16, 2024
ba8a424
chore: rebase
peri044 Apr 16, 2024
097d887
Merge branch 'trt_10' of github.com:pytorch/TensorRT into trt_10
zewenli98 Apr 16, 2024
ffe7a52
chore: rebase with release/2.3
peri044 Apr 18, 2024
38642bb
chore: updates
peri044 Apr 18, 2024
d15dd72
chore: updates
peri044 Apr 19, 2024
dee9aa0
chore: updates
peri044 Apr 19, 2024
c05d675
chore: update stream in python runtime
peri044 Apr 19, 2024
2329657
chore: update hw_compat.ts
peri044 Apr 19, 2024
b8a8709
chore: updates
peri044 Apr 20, 2024
44778e1
chore: updates
peri044 Apr 20, 2024
0dbbcd7
chore: updates
peri044 Apr 20, 2024
89c3d76
chore: updates
peri044 Apr 22, 2024
9 changes: 8 additions & 1 deletion .github/scripts/install-torch-tensorrt.sh
@@ -5,6 +5,13 @@ source ${BUILD_ENV_FILE}
${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision
${CONDA_RUN} python -m pip install pyyaml mpmath==1.3.0
export TRT_VERSION=$(${CONDA_RUN} python -c "import versions; versions.tensorrt_version()")
${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.ngc.nvidia.com

# Install TensorRT manually
wget -q -P /opt/torch-tensorrt-builds/ https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.0/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
tar -xzf /opt/torch-tensorrt-builds/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz -C /opt/torch-tensorrt-builds/
python -m pip install /opt/torch-tensorrt-builds/TensorRT-10.0.0.6/python/tensorrt-10.0.0b6-cp${PYTHON_VERSION//./}-none-linux_x86_64.whl

# Install Torch-TensorRT
${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl

echo -e "Running test script";
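The wheel filename in the manual TensorRT install step is assembled with bash parameter expansion: `${PYTHON_VERSION//./}` deletes the dots from the interpreter version to form the CPython wheel tag. A minimal sketch of that expansion (the `PYTHON_VERSION` value here is a hypothetical example; CI sets it from the build matrix):

```shell
#!/usr/bin/env bash
# Hypothetical value; in CI, PYTHON_VERSION comes from the job matrix.
PYTHON_VERSION="3.10"

# ${VAR//pattern/} replaces every occurrence of the pattern with nothing,
# so "3.10" becomes "310" and the tag becomes "cp310".
CP_TAG="cp${PYTHON_VERSION//./}"

# Reconstructs the wheel name used in the install line above:
echo "tensorrt-10.0.0b6-${CP_TAG}-none-linux_x86_64.whl"
# prints: tensorrt-10.0.0b6-cp310-none-linux_x86_64.whl
```

Note the single slash form `${VAR/./}` would remove only the first dot; the double slash is what makes multi-component versions like `3.10.1` collapse correctly.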
36 changes: 25 additions & 11 deletions .github/workflows/build-test.yml
@@ -15,7 +15,7 @@ on:

jobs:
generate-matrix:
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@release/2.3
with:
package-type: wheel
os: linux
@@ -37,11 +37,11 @@ jobs:
- repository: pytorch/tensorrt
pre-script: packaging/pre_build_script.sh
env-var-script: packaging/env_vars.txt
post-script: ""
smoke-test-script: ""
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
package-name: torch_tensorrt
name: Build torch-tensorrt whl package
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@release/2.3
with:
repository: ${{ matrix.repository }}
ref: ""
@@ -65,7 +65,8 @@ jobs:
- repository: pytorch/tensorrt
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
post-script: packaging/post_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-torchscript-fe
repository: "pytorch/tensorrt"
@@ -77,9 +78,11 @@
script: |
export USE_HOST_DEPS=1
export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/modules
${CONDA_RUN} python -m pip install --pre -r requirements.txt --use-deprecated=legacy-resolver
# Don't use requirements.txt here as it contains tensorrt and torch which should have been installed by now.
${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers timm pybind11==2.6.2
${CONDA_RUN} python hub.py
popd
pushd .
@@ -100,7 +103,8 @@ jobs:
- repository: pytorch/tensorrt
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
post-script: packaging/post_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-dynamo-converters
repository: "pytorch/tensorrt"
@@ -111,6 +115,7 @@
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -127,7 +132,8 @@ jobs:
- repository: pytorch/tensorrt
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
post-script: packaging/post_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-dynamo-fe
repository: "pytorch/tensorrt"
@@ -138,6 +144,7 @@
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -155,7 +162,8 @@ jobs:
- repository: pytorch/tensorrt
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
post-script: packaging/post_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-dynamo-serde
repository: "pytorch/tensorrt"
@@ -166,6 +174,7 @@
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -182,7 +191,8 @@ jobs:
- repository: pytorch/tensorrt
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
post-script: packaging/post_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-torch-compile-be
repository: "pytorch/tensorrt"
@@ -193,6 +203,7 @@
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -211,7 +222,8 @@ jobs:
- repository: pytorch/tensorrt
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
post-script: packaging/post_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-dynamo-core
repository: "pytorch/tensorrt"
@@ -222,6 +234,7 @@
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -251,6 +264,7 @@ jobs:
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
pushd .
cd tests/py/core
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
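Each test job in the workflow above prepends the unpacked TensorRT 10 `lib` directory to `LD_LIBRARY_PATH`, so the dynamic loader resolves the TRT 10 shared libraries ahead of any system-installed copy. A small sketch of how that colon-separated prepend behaves (the starting value here is a stand-in for whatever the runner already exports):

```shell
#!/usr/bin/env bash
# Hypothetical pre-existing search path on the CI runner.
LD_LIBRARY_PATH="/usr/lib64"

# Prepend the TensorRT 10 lib dir; ld.so searches entries left to right,
# so this copy of libnvinfer is found before any system install.
LD_LIBRARY_PATH="/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH"

echo "$LD_LIBRARY_PATH"
# prints: /opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:/usr/lib64
```

Prepending rather than overwriting keeps CUDA and other runner-provided libraries resolvable while still pinning the TRT version under test.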
4 changes: 2 additions & 2 deletions README.md
@@ -116,10 +116,10 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.

- Bazel 5.2.0
- Libtorch 2.3.0.dev (latest nightly) (built with CUDA 12.1)
- Libtorch 2.3.0 (built with CUDA 12.1)
- CUDA 12.1
- cuDNN 8.9.5
- TensorRT 8.6.1
- TensorRT 10.0.0.6

## Prebuilt Binaries and Wheel files

9 changes: 7 additions & 2 deletions core/conversion/converters/converter_util.cpp
@@ -39,6 +39,12 @@ nvinfer1::ITensor* addPadding(
}
}

nvinfer1::ITensor* getShapeOutput(ConversionCtx* ctx, nvinfer1::ITensor* input_tensor, const std::string& name) {
nvinfer1::ITensor* input_shape = ctx->net->addShape(*input_tensor)->getOutput(0);
input_shape = castITensor(ctx, input_shape, nvinfer1::DataType::kINT32, name);
return input_shape;
}

nvinfer1::ITensor* addUnpadding(
ConversionCtx* ctx,
const torch::jit::Node* n,
@@ -134,7 +140,7 @@ nvinfer1::ILayer* add_elementwise(
}
auto otherStaticShapeMask = tensor_to_const(ctx, thOtherStaticShapeMask);
auto otherDynamicShapeMask = tensor_to_const(ctx, thOtherDynamicShapeMask);
auto selfShape = ctx->net->addShape(*self)->getOutput(0);
nvinfer1::ITensor* selfShape = getShapeOutput(ctx, self, std::string(name + "_shape_cast").c_str());
// size of dynamic dimension of other need to the same as that of
// corresponding dimension of self
auto otherDynamicShape =
@@ -348,7 +354,6 @@ nvinfer1::ITensor* normalize_indices(
auto neg_itensor = tensor_to_const(ctx, neg);
// find the indices that = -1
auto signs = clamp(ctx, indices, neg_itensor, zero_itensor, "clamp layer for " + name);

// get the inputDim value where indices == -1, else 0
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, signs, input_dim, "prod layer for " + name);
TORCHTRT_CHECK(mul, "Unable to create mul layer in normalize_indices");
3 changes: 3 additions & 0 deletions core/conversion/converters/converter_util.h
@@ -62,6 +62,9 @@ nvinfer1::ITensor* castITensor(
nvinfer1::DataType dtype,
const std::string& layer_name_prefix = "");

// Get the shape of the input tensor and cast it to INT32 type
nvinfer1::ITensor* getShapeOutput(ConversionCtx* ctx, nvinfer1::ITensor* input_tensor, const std::string& name = "");

// Freeze an at::Tensor in a IConstant layer
nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name = std::string());

4 changes: 0 additions & 4 deletions core/conversion/converters/impl/chunk.cpp
@@ -17,7 +17,6 @@ auto cat_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
auto chunks = args[1].unwrapToInt();
auto dim = args[2].unwrapToInt();
bool dynamic_shape = ctx->input_is_dynamic;
int size = in->getDimensions().nbDims;
int maxDim = static_cast<int32_t>(in->getDimensions().d[dim]);

c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
@@ -41,9 +40,6 @@
size_.nbDims = nbdims;
stride_.nbDims = nbdims;

int startIdx = 0;
int endIdx = maxDim;

for (int i = 0; i < nbdims; i++) {
start_.d[i] = 0;
size_.d[i] = 0;
11 changes: 4 additions & 7 deletions core/conversion/converters/impl/constant_pad.cpp
@@ -55,18 +55,15 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
util::toDims(c10::IntArrayRef(stride)));
TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
slice_layer->setName((util::node_info(n) + "_slice").c_str());
slice_layer->setMode(nvinfer1::SliceMode::kFILL);
slice_layer->setMode(nvinfer1::SampleMode::kFILL);
slice_layer->setInput(4, *value_itensor);

if (ctx->input_is_dynamic) {
// build the size using inetwork layers
auto shape_layer = ctx->net->addShape(*in);
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
shape_layer->setName((util::node_info(n) + "_shape").c_str());
auto total_padding_itensor = tensor_to_const(ctx, torch::tensor(total_padding, torch::kInt32));

auto add_layer = ctx->net->addElementWise(
*shape_layer->getOutput(0), *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
nvinfer1::ITensor* shapeOutput = getShapeOutput(ctx, in, (util::node_info(n) + "_shape").c_str());
auto add_layer =
ctx->net->addElementWise(*shapeOutput, *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
TORCHTRT_CHECK(add_layer, "Unable to create add layer from node: " << *n);
add_layer->setName((util::node_info(n) + "_add").c_str());
slice_layer->setInput(2, *add_layer->getOutput(0));
14 changes: 6 additions & 8 deletions core/conversion/converters/impl/conv_deconv.cpp
@@ -33,7 +33,7 @@ nvinfer1::ILayer* add_bias_layer(
nvinfer1::Dims& input_dims,
nvinfer1::Dims& output_padding,
Weights& bias) {
nvinfer1::ITensor* input_shape = ctx->net->addShape(*input_tensor)->getOutput(0);
nvinfer1::ITensor* input_shape = getShapeOutput(ctx, input_tensor, std::string("bias_shape_cast").c_str());
// Add padding layer
nvinfer1::ITensor* start;
nvinfer1::ITensor* totalPadding;
@@ -61,7 +61,7 @@
auto* sliceLayer = ctx->net->addSlice(*input_tensor, dummy, dummy, stride);
sliceLayer->setInput(1, *start);
sliceLayer->setInput(2, *size);
sliceLayer->setMode(nvinfer1::SliceMode::kFILL);
sliceLayer->setMode(nvinfer1::SampleMode::kFILL);
nvinfer1::ITensor* slice_output = sliceLayer->getOutput(0);

nvinfer1::Dims constantDims;
@@ -146,9 +146,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
// TensorRT expects nbSpatialDims = 2 or 3
filter_dim = util::unsqueezeDims(filter_dim, filter_dim.nbDims, 1, false);
// Reshape input dimensions
in = addPadding(ctx, n, in, 4);
in = addPadding(ctx, n, in, 4, true, true, std::string(util::node_info(n) + "_input_shuffle"));
LOG_DEBUG("Reshaping input dimensions to: " << in->getDimensions());
kernel = addPadding(ctx, n, kernel, 4);
kernel = addPadding(ctx, n, kernel, 4, true, true, std::string(util::node_info(n) + "_kernel_shuffle"));
LOG_DEBUG("Reshaping kernel dimensions to: " << kernel->getDimensions());
if (transposed) {
num_output_maps = kernel_dims.d[1];
@@ -194,7 +194,7 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
nvinfer1::IConvolutionLayer* convLayer =
ctx->net->addConvolutionNd(*in, num_output_maps, filter_dim, kernel_weights, bias.data);
convLayer->setStrideNd(stride);
convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
convLayer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN);
convLayer->setPaddingNd(padding);
convLayer->setPostPadding(out_padding);
convLayer->setDilationNd(dilation);
@@ -291,11 +291,9 @@
// shape of convolution's weight: [out, in/groups, ...]
auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
TORCHTRT_CHECK(conv, "Unable to create convolution layer from node: " << *n);

conv->setStrideNd(stride);
conv->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
conv->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN);
conv->setPaddingNd(padding);
conv->setPostPadding(out_padding);
conv->setDilationNd(dilation);
conv->setNbGroups(groups);
new_layer = conv;
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/cumsum.cpp
@@ -36,7 +36,7 @@ auto cumsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat
torch::Tensor axis = torch::tensor(input_dims.d[dim], torch::kInt32);
tripLimit = tensor_to_const(ctx, axis);
} else {
nvinfer1::ITensor* inpShape = ctx->net->addShape(*in)->getOutput(0);
nvinfer1::ITensor* inpShape = getShapeOutput(ctx, in);
torch::Tensor dimValue = torch::tensor(dim, torch::kInt32);
nvinfer1::ITensor* axis = tensor_to_const(ctx, dimValue);
tripLimit = ctx->net->addGather(*inpShape, *axis, 0)->getOutput(0);
9 changes: 4 additions & 5 deletions core/conversion/converters/impl/expand.cpp
@@ -19,11 +19,11 @@ nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfe
if (max_rank - old_rank > 0) {
torch::Tensor thOne = torch::tensor(std::vector<int32_t>(max_rank - old_rank, 1), torch::kInt32);
auto one_tensor = tensor_to_const(ctx, thOne);
auto in_shape_tensor = ctx->net->addShape(*tensor)->getOutput(0);
auto in_shape_tensor = getShapeOutput(ctx, tensor);
nvinfer1::ITensor* const args[2] = {one_tensor, in_shape_tensor};
return ctx->net->addConcatenation(args, 2)->getOutput(0);
} else { // max_rank - old_rank == 0
return ctx->net->addShape(*tensor)->getOutput(0);
return getShapeOutput(ctx, tensor);
}
}

@@ -221,8 +221,7 @@ auto expand_registrations TORCHTRT_UNUSED =
auto targetDims = targetTensor->getDimensions();
LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
if (ctx->input_is_dynamic) {
return add_expand_dynamic(
ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
return add_expand_dynamic(ctx, n, in, getShapeOutput(ctx, targetTensor), targetDims, false);
} else {
return add_expand(ctx, n, in, targetDims);
}
@@ -357,7 +356,7 @@ auto expand_registrations TORCHTRT_UNUSED =
if (ctx->input_is_dynamic) {
auto start_tensor = tensor_to_const(ctx, torch::tensor(start_vec, torch::kInt32));

auto expand_output_shape = ctx->net->addShape(*expand->getOutput(0))->getOutput(0);
auto expand_output_shape = getShapeOutput(ctx, expand->getOutput(0));
std::vector<int64_t> repeat_const_vec(repeat_shape_dims.nbDims, 1);
repeat_const_vec[dim + 1] = repeats;
auto repeat_const = tensor_to_const(ctx, torch::tensor(repeat_const_vec, torch::kInt32));