Skip to content

Commit

Permalink
fix(): fixed interpolate_plugin to handle dynamically sized inputs fo…
Browse files Browse the repository at this point in the history
…r adaptive_pool2d

Signed-off-by: Abhiram Iyer <[email protected]>

Signed-off-by: Abhiram Iyer <[email protected]>
  • Loading branch information
abhi-iyer committed Jun 19, 2020
1 parent 549ca38 commit 7794c78
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion core/conversion/converters/impl/plugins/interpolate_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,16 @@ size_t InterpolatePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inp
int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs,
void *const *outputs, void *workspace,
cudaStream_t stream) {
at::Tensor input;

if (mode == "adaptive_pool2d") {
// use dynamically inferred input shape (for pooling)
input = at::from_blob((void*) inputs[0], util::toVec(inputDesc->dims), [](void*){}, tensor_options);
} else {
// use precomputed input shape (for interpolation/upsampling)
input = at::from_blob((void*) inputs[0], in_shape, [](void*){}, tensor_options);
}

at::Tensor input = at::from_blob((void*) inputs[0], in_shape, [](void*){}, tensor_options);
at::Tensor output = at::from_blob(outputs[0], out_shape, [](void*){}, tensor_options);

at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
Expand Down

0 comments on commit 7794c78

Please sign in to comment.