Skip to content

Commit

Permalink
fix: Add aten::full_like evaluator
Browse files Browse the repository at this point in the history
- Complete set of required operators for full coverage of HuggingFace
T5 model
- Add evaluator for `full_like` with sufficient generality to double as
functional operator for `ones_like` and `zeros_like`, with minor
additions
- Add thorough testing to ensure type and shape inheritance for returned
tensors agrees with documentation
  • Loading branch information
gs-olive committed Jan 14, 2023
1 parent 0d32562 commit 0aaeecb
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 1 deletion.
49 changes: 48 additions & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,53 @@ auto aten_registrations TORCHTRT_UNUSED =
auto out_tensor = torch::full(args.at(n->input(0)).unwrapToIntList().vec(), scalar_value, options);
return out_tensor;
}})
.evaluator(
    {c10::Symbol::fromQualString("aten::full_like"),
     // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None,
     // Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> (Tensor)
     //
     // Produces a tensor shaped like `self` and filled with `fill_value`.
     //   - input 0 (`self`) supplies the shape and the fallback dtype.
     //   - input 1 (`fill_value`) supplies the fill scalar.
     //   - input 2 (`dtype`) overrides the dtype when it is not None.
     // The remaining inputs (layout, device, pin_memory, memory_format) are not read:
     // the result is always created with a strided layout on CUDA.
     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
       auto self_var = args.at(n->input(0));

       std::vector<int64_t> like_shape;
       c10::ScalarType like_dtype;

       // `self` may arrive in three forms; pull shape and dtype out of whichever one we got.
       if (self_var.isITensor()) {
         // TensorRT tensor: translate its dimensions and TRT data type.
         auto trt_tensor = self_var.ITensor();
         like_shape = util::toVec(trt_tensor->getDimensions());
         like_dtype = util::TRTDataTypeToScalarType(trt_tensor->getType());
       } else if (self_var.IValue()->isTensor()) {
         // Torch tensor: read sizes and scalar type directly.
         auto torch_tensor = self_var.unwrapToTensor();
         like_shape = torch_tensor.sizes().vec();
         like_dtype = torch_tensor.scalar_type();
       } else if (self_var.IValue()->isCustomClass()) {
         // Custom-class IValue holding a TensorContainer: its tensor() is queried like the ITensor case.
         auto held_tensor = self_var.IValue()->toCustomClass<TensorContainer>()->tensor();
         like_shape = util::toVec(held_tensor->getDimensions());
         like_dtype = util::TRTDataTypeToScalarType(held_tensor->getType());
       } else {
         TORCHTRT_THROW_ERROR(
             "Invalid IValue type. IValue is not some class of torch::Tensor or nvinfer1::ITensor. Found: "
             << self_var.IValue()->type());
       }

       // An explicit dtype (third input) wins; when it is absent or None, inherit the dtype of `self`.
       auto& dtype_var = args.at(n->input(2));
       const bool dtype_missing = dtype_var.isNone() || dtype_var.IValue()->isNone();
       const c10::ScalarType result_dtype =
           dtype_missing ? like_dtype : static_cast<c10::ScalarType>(dtype_var.unwrapToInt());

       // Layout and device are pinned for TensorRT regardless of what the node requested.
       const auto result_options =
           torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA).dtype(result_dtype);

       // Materialize the filled tensor.
       auto fill_value = args.at(n->input(1)).unwrapToScalar();
       auto filled = torch::full(like_shape, fill_value, result_options);
       return filled;
     },
     EvalOptions().validSchemas(
         {R"SIG(aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None,
Layout? layout=None, Device? device=None, bool? pin_memory=None,
MemoryFormat? memory_format=None) -> (Tensor))SIG"})})
.evaluator(
{c10::Symbol::fromQualString("aten::slice"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
Expand Down Expand Up @@ -821,4 +868,4 @@ auto aten_registrations TORCHTRT_UNUSED =
} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
} // namespace torch_tensorrt
65 changes: 65 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,71 @@ TEST(Evaluators, FullEvaluatesCorrectly) {
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}

// aten::full_like with a float fill value (1.3) and an explicit integer dtype code (4):
// the tensor returned by EvaluateGraph must match the one returned by EvaluateGraphJIT
// in both contents and dtype.
// NOTE(review): dtype code 4 is presumably c10::ScalarType::Long - confirm against the ScalarType enum.
TEST(Evaluators, FullLikeEvaluatesCorrectly) {
  const auto ir_source = R"IR(
graph(%x.1 : Tensor):
%9 : None = prim::Constant()
%13 : float = prim::Constant[value=1.3]()
%14 : int = prim::Constant[value=4]()
%35 : Device = prim::Constant[value="cuda:0"]()
%19 : Tensor = aten::full_like(%x.1, %13, %14, %9, %35, %9, %9)
return (%19))IR";

  // Parse the IR text into a fresh graph.
  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir_source, parsed_graph.get());

  // Only the shape/dtype of the input matter to full_like; values are random.
  auto sample_input = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA});

  auto jit_outputs = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed_graph, {sample_input});
  auto trt_outputs = torch_tensorrt::tests::util::EvaluateGraph(parsed_graph->block(), {sample_input});

  const auto jit_tensor = jit_outputs[0].toTensor();
  const auto trt_tensor = trt_outputs[0].toTensor();

  ASSERT_TRUE(at::equal(jit_tensor.to(at::kCUDA), trt_tensor));
  ASSERT_TRUE(jit_tensor.dtype() == trt_tensor.dtype());
}

// aten::full_like with a Scalar fill value (1) and an explicit dtype code (11) that differs
// from the input tensor: the EvaluateGraph result must match the EvaluateGraphJIT result
// in both contents and dtype.
// NOTE(review): dtype code 11 is presumably c10::ScalarType::Bool - confirm against the ScalarType enum.
TEST(Evaluators, FullLikeNewDtypeEvaluatesCorrectly) {
  const auto ir_source = R"IR(
graph(%x.1 : Tensor):
%9 : None = prim::Constant()
%13 : Scalar = prim::Constant[value=1]()
%14 : int = prim::Constant[value=11]()
%35 : Device = prim::Constant[value="cuda:0"]()
%19 : Tensor = aten::full_like(%x.1, %13, %14, %9, %35, %9, %9)
return (%19))IR";

  // Parse the IR text into a fresh graph.
  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir_source, parsed_graph.get());

  // Only the shape/dtype of the input matter to full_like; values are random.
  auto sample_input = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA});

  auto jit_outputs = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed_graph, {sample_input});
  auto trt_outputs = torch_tensorrt::tests::util::EvaluateGraph(parsed_graph->block(), {sample_input});

  const auto jit_tensor = jit_outputs[0].toTensor();
  const auto trt_tensor = trt_outputs[0].toTensor();

  ASSERT_TRUE(at::equal(jit_tensor.to(at::kCUDA), trt_tensor));
  ASSERT_TRUE(jit_tensor.dtype() == trt_tensor.dtype());
}

// aten::full_like with a Scalar fill value (1.5) and dtype left as None: the output dtype
// is expected to be inherited from the input, which is cast to Int32 here. The EvaluateGraph
// result must match the EvaluateGraphJIT result in both contents and dtype.
TEST(Evaluators, FullLikeOldDtypeEvaluatesCorrectly) {
  const auto ir_source = R"IR(
graph(%x.1 : Tensor):
%9 : None = prim::Constant()
%13 : Scalar = prim::Constant[value=1.5]()
%35 : Device = prim::Constant[value="cuda:0"]()
%19 : Tensor = aten::full_like(%x.1, %13, %9, %9, %35, %9, %9)
return (%19))IR";

  // Parse the IR text into a fresh graph.
  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir_source, parsed_graph.get());

  // Int32 input so the inherited dtype differs from the floating-point fill value.
  auto sample_input = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA}).to(torch::kInt32);

  auto jit_outputs = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed_graph, {sample_input});
  auto trt_outputs = torch_tensorrt::tests::util::EvaluateGraph(parsed_graph->block(), {sample_input});

  const auto jit_tensor = jit_outputs[0].toTensor();
  const auto trt_tensor = trt_outputs[0].toTensor();

  ASSERT_TRUE(at::equal(jit_tensor.to(at::kCUDA), trt_tensor));
  ASSERT_TRUE(jit_tensor.dtype() == trt_tensor.dtype());
}

TEST(Evaluators, OnesDataTypeEvaluatesCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
Expand Down

0 comments on commit 0aaeecb

Please sign in to comment.