Skip to content

Commit

Permalink
fix(//core/conversion/converters/impl/element_wise): Fix broadcast
Browse files Browse the repository at this point in the history
support

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 4, 2020
1 parent 0548540 commit a9f33e4
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 19 deletions.
30 changes: 17 additions & 13 deletions core/conversion/converters/impl/element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@ namespace converters {
namespace impl {
namespace {

nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, float scalar=1) {
nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, const std::string& name, float scalar=1) {
auto self_dims = self->getDimensions();
auto self_dims_vec = util::toVec(self_dims);
auto other_dims = other->getDimensions();
auto other_dims_vec = util::toVec(other_dims);
auto other_batch = other_dims_vec[0];

TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims), "Found inputs to elementwise operation do not have the same number of elements:\n Found: self " << self_dims << " other " << other_dims);
// TODO: Proper broadcast check
TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims) || util::volume(self_dims) == util::volume(other_dims) / other_batch, "Found inputs to elementwise operation do not have the same number of elements or is not broadcastable:\n Found: self " << self_dims << " other " << other_dims);

if (self_dims != other_dims) {
LOG_DEBUG("Input shape dont match inserting shuffle layers to reshape to " << self_dims);
auto other_shuffle = ctx->net->addShuffle(*other);
other_shuffle->setReshapeDimensions(self_dims);
other_shuffle->setName(std::string("[Reshape other to " + util::toStr(self_dims) + ']').c_str());
other = other_shuffle->getOutput(0);
auto self_shuffle = ctx->net->addShuffle(*self);
self_shuffle->setReshapeDimensions(util::toDimsPad(self_dims_vec, other_dims_vec.size()));
self_shuffle->setName(std::string("[Reshape self to " + util::toStr(self_dims) + " for broadcasting (" + name + ")]").c_str());
self = self_shuffle->getOutput(0);
}


Expand Down Expand Up @@ -72,7 +76,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
auto self = args[0].ITensor();
auto other = args[1].ITensor();
auto scalar = args[2].unwrapToScalar().to<float>();
auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, scalar);
auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar);

TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n);

Expand All @@ -89,7 +93,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
auto self = args[0].ITensor();
auto other = args[1].ITensor();
auto scalar = args[2].unwrapToScalar().to<float>();
auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, scalar);
auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar);

TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n);

Expand All @@ -106,7 +110,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
auto self = args[0].ITensor();
auto other = args[1].ITensor();
auto scalar = args[2].unwrapToScalar().to<float>();
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, scalar);
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, util::node_info(n), scalar);

TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n);

Expand All @@ -122,7 +126,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
// Should implement self / other
auto self = args[0].ITensor();
auto other = args[1].ITensor();
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other);
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));

TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);

Expand All @@ -138,7 +142,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
// TODO: Remove with functionalization
auto self = args[0].ITensor();
auto other = args[1].ITensor();
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other);
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));

TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);

Expand All @@ -154,7 +158,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
// Should implement self * other
auto self = args[0].ITensor();
auto other = args[1].ITensor();
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other);
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));

TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);

Expand All @@ -170,7 +174,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
// TODO: Remove with functionalization
auto self = args[0].ITensor();
auto other = args[1].ITensor();
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other);
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));

TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);

Expand Down
1 change: 1 addition & 0 deletions tests/accuracy/accuracy_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class AccuracyTests
std::cerr << "error loading the model\n";
return;
}
mod.eval();
}

void TearDown() {
Expand Down
4 changes: 2 additions & 2 deletions tests/accuracy/test_fp16_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
jit_total += targets.sizes()[0];
jit_correct += torch::sum(torch::eq(predictions, targets));
}
torch::Tensor jit_accuracy = jit_correct / jit_total;
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;

std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
auto extra_info = trtorch::ExtraInfo({input_shape});
Expand All @@ -45,7 +45,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
trt_correct += torch::sum(torch::eq(predictions, targets));
}

torch::Tensor trt_accuracy = trt_correct / trt_total;
torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3));
}
Expand Down
4 changes: 2 additions & 2 deletions tests/accuracy/test_fp32_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
jit_total += targets.sizes()[0];
jit_correct += torch::sum(torch::eq(predictions, targets));
}
torch::Tensor jit_accuracy = jit_correct / jit_total;
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;

std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
auto extra_info = trtorch::ExtraInfo({input_shape});
Expand All @@ -45,7 +45,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
trt_correct += torch::sum(torch::eq(predictions, targets));
}

torch::Tensor trt_accuracy = trt_correct / trt_total;
torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3));
}
Expand Down
4 changes: 2 additions & 2 deletions tests/accuracy/test_int8_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
jit_total += targets.sizes()[0];
jit_correct += torch::sum(torch::eq(predictions, targets));
}
torch::Tensor jit_accuracy = jit_correct / jit_total;
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;

// Compile Graph
auto trt_mod = trtorch::CompileGraph(mod, extra_info);
Expand All @@ -72,7 +72,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
trt_total += targets.sizes()[0];
trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
}
torch::Tensor trt_accuracy = trt_correct / trt_total;
torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3));
}
Expand Down

0 comments on commit a9f33e4

Please sign in to comment.