Skip to content

Commit

Permalink
fix(aten::batchnorm|aten::view): Fix converter implementation for
Browse files Browse the repository at this point in the history
dynamic inputs

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 31, 2020
1 parent 736e914 commit bf651dd
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
20 changes: 16 additions & 4 deletions core/conversion/converters/impl/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,24 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
auto orig_shape = input->getDimensions();
auto shape = util::toVec(orig_shape);
auto options = torch::TensorOptions().dtype(torch::kFloat32);
auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
auto beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));

torch::Tensor gamma, beta, mean, var;

if (ctx->input_is_dynamic) {
gamma = args[1].unwrapToTensor();
beta = args[2].unwrapToTensor();
mean = args[3].unwrapToTensor();
var = args[4].unwrapToTensor();
} else {
gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
}

auto eps = args[7].unwrapToDouble(1e-5f);


LOG_DEBUG("momentum disregarded");
LOG_DEBUG("training disregarded");
LOG_DEBUG("cudnn disregarded");
Expand Down
6 changes: 2 additions & 4 deletions core/conversion/converters/impl/shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace converters {
namespace impl {
namespace {

static auto shuffle_registrations = RegisterNodeConversionPatterns()
static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
Expand Down Expand Up @@ -50,12 +50,10 @@ static auto shuffle_registrations = RegisterNodeConversionPatterns()
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto in_shape = util::toVec(in->getDimensions());
auto ex_tensor = torch::rand(in_shape);
auto new_shape = ex_tensor.view(args[1].unwrapToIntList().vec()).sizes();

auto shuffle = ctx->net->addShuffle(*in);
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
shuffle->setReshapeDimensions(util::toDims(new_shape));
shuffle->setReshapeDimensions(util::toDims(args[1].unwrapToIntList().vec()));
shuffle->setName(util::node_info(n).c_str());

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
Expand Down

0 comments on commit bf651dd

Please sign in to comment.