Skip to content

Commit

Permalink
feat(aten::view): Adds support for ATen view also fixes some tests
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 20, 2020
1 parent c4b62a6 commit 24b422e
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 26 deletions.
18 changes: 18 additions & 0 deletions core/conversion/converters/impl/shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ static auto shuffle_registrations = RegisterNodeConversionPatterns()
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

return true;
}
}).pattern({
"aten::view(Tensor(a) self, int[] size) -> (Tensor(a))",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto in_shape = util::toVec(in->getDimensions());
auto ex_tensor = torch::rand(in_shape);
auto new_shape = ex_tensor.view(args[1].unwrapToIntList().vec()).sizes();

auto shuffle = ctx->net->addShuffle(*in);
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
shuffle->setReshapeDimensions(util::toDims(new_shape));
shuffle->setName(util::node_info(n).c_str());

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

return true;
}
});
Expand Down
7 changes: 7 additions & 0 deletions tests/core/converters/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
load("//tests/core/converters:converter_test.bzl", "converter_test")

config_setting(
name = "use_pre_cxx11_abi",
values = {
"define": "abi=pre_cxx11_abi",
}
)

converter_test(
name = "test_activation"
)
Expand Down
6 changes: 0 additions & 6 deletions tests/core/converters/converter_test.bzl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
config_setting(
name = "use_pre_cxx11_abi",
values = {
"define": "abi=pre_cxx11_abi",
}
)

def converter_test(name, visibility=None):
native.cc_test(
Expand Down
24 changes: 24 additions & 0 deletions tests/core/converters/test_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,29 @@ TEST(Converters, ATenReshapeConvertsCorrectly) {
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenViewConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=3]()
%2 : int = prim::Constant[value=2]()
%3 : int[] = prim::ListConstruct(%1, %2)
%4 : Tensor = aten::view(%0, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
83 changes: 65 additions & 18 deletions tests/modules/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,72 @@
import torchvision.models as models

models = {
"alexnet": models.alexnet(pretrained=True),
"vgg16": models.vgg16(pretrained=True),
"squeezenet": models.squeezenet1_0(pretrained=True),
"densenet": models.densenet161(pretrained=True),
"inception_v3": models.inception_v3(pretrained=True),
"alexnet": {
"model": models.alexnet(pretrained=True),
"path": "both"
},
"vgg16": {
"model": models.vgg16(pretrained=True),
"path": "both"
},
"squeezenet": {
"model": models.squeezenet1_0(pretrained=True),
"path": "both"
},
"densenet": {
"model": models.densenet161(pretrained=True),
"path": "both"
},
"inception_v3": {
"model": models.inception_v3(pretrained=True),
"path": "both"
},
#"googlenet": models.googlenet(pretrained=True),
"shufflenet": models.shufflenet_v2_x1_0(pretrained=True),
"mobilenet_v2": models.mobilenet_v2(pretrained=True),
"resnext50_32x4d": models.resnext50_32x4d(pretrained=True),
"wideresnet50_2": models.wide_resnet50_2(pretrained=True),
"mnasnet": models.mnasnet1_0(pretrained=True),
"resnet18": torch.hub.load('pytorch/vision:v0.5.0', 'resnet18', pretrained=True),
"resnet50": torch.hub.load('pytorch/vision:v0.5.0', 'resnet50', pretrained=True)}
"shufflenet": {
"model": models.shufflenet_v2_x1_0(pretrained=True),
"path": "both"
},
"mobilenet_v2": {
"model": models.mobilenet_v2(pretrained=True),
"path": "both"
},
"resnext50_32x4d": {
"model": models.resnext50_32x4d(pretrained=True),
"path": "both"
},
"wideresnet50_2": {
"model": models.wide_resnet50_2(pretrained=True),
"path": "both"
},
"mnasnet": {
"model": models.mnasnet1_0(pretrained=True),
"path": "both"
},
"resnet18": {
"model": torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True),
"path": "both"
},
"resnet50": {
"model":torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True),
"path": "both"
},
"fcn_resnet101": {
"model": torch.hub.load('pytorch/vision:v0.6.0', 'fcn_resnet101', pretrained=True),
"path": "script"
},
"ssd": {
"model": torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math="fp32"),
"path": "trace"
}
}

for n, m in models.items():
print("Downloading {}".format(n))
m = m.eval().cuda()
x = torch.ones((1, 3, 224, 224)).cuda()
trace_model = torch.jit.trace(m, x)
torch.jit.save(trace_model, n + '_traced.jit.pt')
script_model = torch.jit.script(m)
torch.jit.save(script_model, n + '_scripted.jit.pt')
m["model"] = m["model"].eval().cuda()
x = torch.ones((1, 3, 300, 300)).cuda()
if m["path"] == "both" or m["path"] == "trace":
trace_model = torch.jit.trace(m["model"], [x])
torch.jit.save(trace_model, n + '_traced.jit.pt')
if m["path"] == "both" or m["path"] == "script":
script_model = torch.jit.script(m["model"])
torch.jit.save(script_model, n + '_scripted.jit.pt')
2 changes: 0 additions & 2 deletions tests/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ cc_library(
],
deps = [
"@tensorrt//:nvinfer",
"@libtorch//:libtorch",
"@libtorch//:caffe2",
"//core/conversion",
"//core/util:prelude",
"//cpp/api:trtorch",
Expand Down

0 comments on commit 24b422e

Please sign in to comment.