feat(aten::masked_fill): In progress implementation of masked_fill
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Jul 28, 2021
1 parent 6aaba3b commit fa7d6d9
Showing 8 changed files with 310 additions and 1 deletion.
21 changes: 21 additions & 0 deletions core/conversion/converters/converter_util.cpp
@@ -122,6 +122,27 @@ nvinfer1::ILayer* add_elementwise(
return ele;
}

nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype) {
if (tensor->getType() != dtype) {
std::ostringstream tensor_id;
tensor_id << reinterpret_cast<int*>(tensor);

auto id_layer = ctx->net->addIdentity(*tensor);
TRTORCH_CHECK(id_layer, "Unable to create identity layer for ITensor: " << tensor_id.str());
auto casted_tensor = id_layer->getOutput(0);
casted_tensor->setType(dtype);

LOG_DEBUG(ctx->logger, "Casting ITensor " << tensor_id.str() << " from " << tensor->getType() << " to " << dtype);

std::stringstream ss;
ss << "[Cast ITensor " << tensor_id.str() << " from " << tensor->getType() << " to " << dtype << "]";
id_layer->setName(ss.str().c_str());
return casted_tensor;
} else {
return tensor;
}
}

} // namespace converters
} // namespace conversion
} // namespace core
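For reference, a minimal usage sketch of the new helper (a hypothetical ensure_bool wrapper, not part of this commit; assumes a live ConversionCtx* as used throughout the converters):

#include "NvInfer.h"
#include "core/conversion/converters/converter_util.h"

// Hypothetical wrapper sketching the helper's contract: castITensor returns
// the input unchanged when it already has the requested type; otherwise it
// returns the output of a new IIdentityLayer whose output type is forced to
// the target dtype.
nvinfer1::ITensor* ensure_bool(trtorch::core::conversion::ConversionCtx* ctx, nvinfer1::ITensor* t) {
  return trtorch::core::conversion::converters::castITensor(ctx, t, nvinfer1::DataType::kBOOL);
}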
3 changes: 3 additions & 0 deletions core/conversion/converters/converter_util.h
@@ -42,6 +42,9 @@ nvinfer1::ILayer* add_elementwise(
nvinfer1::ITensor* other,
const std::string& name);

// If an ITensor is not already of type dtype, add an Identity layer to cast it to dtype
nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype);

} // namespace converters
} // namespace conversion
} // namespace core
24 changes: 23 additions & 1 deletion core/conversion/converters/impl/select.cpp
@@ -3,6 +3,7 @@
#include "NvInfer.h"
#include "c10/util/intrusive_ptr.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/converters/converter_util.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"
#include "torch/torch.h"
@@ -247,7 +248,28 @@ auto select_registrations TRTORCH_UNUSED =
add_split(ctx, n, args, true);
LOG_DEBUG("Converted split op into a list of IValues");
return true;
}});
}})
.pattern({
"aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
LOG_DEBUG(args[1].unwrapToTensor());
auto mask = castITensor(ctx, args[1].ITensorOrFreeze(ctx), nvinfer1::DataType::kBOOL);
auto val = args[2].unwrapToScalar().to<float>();
LOG_DEBUG(torch::full(util::toVec(self->getDimensions()), val));
auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val));

TRTORCH_CHECK(util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false), "Self and mask tensors are not broadcastable");

// ISelectLayer computes (condition ? then : else), so val_t must be the
// "then" input: fill where the mask is true, keep self elsewhere
auto new_layer = ctx->net->addSelect(*mask, *val_t, *self);
TRTORCH_CHECK(new_layer, "Unable to create layer for aten::masked_fill");

new_layer->setName(util::node_info(n).c_str());

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}});

} // namespace
} // namespace impl
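The converter lowers aten::masked_fill to a single ISelectLayer, which picks the fill value where the mask is true and the original element elsewhere. A minimal libtorch sketch of the semantics the layer must reproduce (a standalone illustration, not converter code):

#include "torch/torch.h"

// out[i] = mask[i] ? value : self[i], the same elementwise select that
// addSelect(mask, val_t, self) builds in the TensorRT graph.
void masked_fill_reference() {
  auto self = torch::zeros({1, 2, 3});
  auto mask = torch::tensor({0, 0, 1, 1, 1, 0}, torch::kBool).reshape({1, 2, 3});
  auto expected = self.masked_fill(mask, 2.0);
  auto via_select = torch::where(mask, torch::full_like(self, 2.0), self);
  TORCH_CHECK(expected.equal(via_select));
}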
16 changes: 16 additions & 0 deletions core/conversion/evaluators/aten.cpp
@@ -8,6 +8,7 @@
#include "torch/torch.h"

#include "core/conversion/evaluators/eval_macros.h"
#include "core/conversion/evaluators/eval_util.h"
#include "core/conversion/evaluators/evaluators.h"

namespace trtorch {
@@ -566,6 +567,21 @@ auto aten_registrations TRTORCH_UNUSED =
return {};
},
EvalOptions()})
.evaluator({c10::Symbol::fromQualString("aten::tensor"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto data = args.at(n->input(0)).IValue();
auto dtype = args.at(n->input(1)).IValue();
auto device = args.at(n->input(2)).IValue();
auto tensor = createTensorFromList(*data, *dtype, *device);
LOG_DEBUG(tensor);
if (tensor.dtype() == at::kByte) {
return tensor.to(at::kInt);
}
return tensor;
},
EvalOptions().validSchemas({
"aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)"
})})
.evaluator({c10::Symbol::fromQualString("aten::arange"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
int input_size = n->inputs().size();
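A minimal standalone sketch of what the new aten::tensor evaluator computes at conversion time (a hypothetical direct call; assumes createTensorFromList mirrors torch.tensor on nested constant lists):

#include "core/conversion/evaluators/eval_util.h"
#include "torch/torch.h"

void tensor_evaluator_sketch() {
  // Build the IValue the evaluator would receive for the literal [[0, 0, 1], [1, 1, 0]]
  c10::List<c10::List<int64_t>> rows;
  rows.push_back(c10::List<int64_t>({0, 0, 1}));
  rows.push_back(c10::List<int64_t>({1, 1, 0}));
  torch::jit::IValue data(rows), none; // dtype=None, device=None
  auto t = trtorch::core::conversion::evaluators::createTensorFromList(data, none, none);
  // t is a {2, 3} integer tensor; the evaluator above additionally widens
  // kByte results to kInt before handing them to converters.
  TORCH_CHECK(t.dim() == 2 && t.size(0) == 2 && t.size(1) == 3);
}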
201 changes: 201 additions & 0 deletions core/conversion/evaluators/eval_util.cpp
@@ -1,6 +1,9 @@
#include "ATen/InitialTensorOptions.h"
#include "ATen/core/List.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
#include "ATen/core/jit_type.h"
#include "c10/util/irange.h"
#include "core/util/prelude.h"

namespace trtorch {
@@ -91,6 +94,204 @@ c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
}
}

void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
if (!elem_type->isSubtypeOf(c10::NumberType::get()) &&
elem_type != c10::BoolType::get()) {
std::stringstream error;
error << "Input must be of ints, floats, or bools, "
<< "got " << elem_type->repr_str();
// special case empty list torch.tensor([])
if (elem_type->isSubtypeOf(c10::TensorType::get())) {
if (empty_list) {
error << "\nEmpty lists default to List[Tensor]. Add a variable "
"annotation to the assignment to create an empty list "
"of another type (torch.jit.annotate(List[T, []]) where T "
"is the type of elements in the list for Python 2)";
}
}
TRTORCH_THROW_ERROR(error.str());
}
}

void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
if (seq_size != n) {
TRTORCH_THROW_ERROR(
"Expected sequence of length "
<< n
<< " at dim "
<< dim
<< " (got "
<< seq_size
<< ")");
}
}

template <typename DTYPE>
void storeLastDimension(
char* data,
const std::vector<int64_t>& sizes,
const c10::ArrayRef<int64_t>& strides,
int64_t dim,
int elementSize,
at::ArrayRef<torch::jit::IValue> obj) {
auto n = sizes[dim];
auto seq_size = obj.size();
checkSequenceSize(n, dim, seq_size);
for (const auto i : c10::irange(n)) {
*(DTYPE*)data = obj[i].to<DTYPE>();
data += strides[dim] * elementSize;
}
}

void storeLastDimensionFloat(
char* data,
const std::vector<int64_t>& sizes,
const c10::ArrayRef<int64_t>& strides,
int64_t dim,
int elementSize,
at::ArrayRef<torch::jit::IValue> obj) {
auto n = sizes[dim];
auto seq_size = obj.size();
checkSequenceSize(n, dim, seq_size);
for (int64_t i = 0; i < n; i++) {
*(float*)data = static_cast<float>(obj[i].to<double>());
data += strides[dim] * elementSize;
}
}

void storeLastDimensionHalf(
char* data,
const std::vector<int64_t>& sizes,
const c10::ArrayRef<int64_t>& strides,
int64_t dim,
int elementSize,
at::ArrayRef<torch::jit::IValue> obj) {
auto n = sizes[dim];
auto seq_size = obj.size();
checkSequenceSize(n, dim, seq_size);
for (int64_t i = 0; i < n; i++) {
*(at::Half*)data = at::convert<at::Half, double>(obj[i].to<double>());
data += strides[dim] * elementSize;
}
}

void recursiveStore(
char* data,
const std::vector<int64_t>& sizes,
const c10::ArrayRef<int64_t>& strides,
int64_t dim,
int tenElementSize,
const torch::jit::IValue& obj) {
auto ndim = sizes.size();
auto n = sizes[dim];
auto seq = obj.toListRef();
checkSequenceSize(n, dim, seq.size());
if (dim + 1 < static_cast<long>(ndim)) {
for (const auto i : c10::irange(n)) {
recursiveStore(data, sizes, strides, dim + 1, tenElementSize, seq[i]);
data += strides[dim] * tenElementSize;
}
} else {
if (obj.isIntList()) {
storeLastDimension<int64_t>(
data, sizes, strides, dim, tenElementSize, seq);
} else if (obj.isBoolList()) {
storeLastDimension<bool>(data, sizes, strides, dim, tenElementSize, seq);
} else if (obj.isDoubleList()) {
if (tenElementSize ==
static_cast<int>(elementSize(at::ScalarType::Double))) {
storeLastDimension<double>(
data, sizes, strides, dim, tenElementSize, seq);
} else if (
tenElementSize ==
static_cast<int>(elementSize(at::ScalarType::Float))) {
storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, seq);
} else if (
tenElementSize ==
static_cast<int>(elementSize(at::ScalarType::Half))) {
storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, seq);
} else {
TORCH_INTERNAL_ASSERT(false);
}
} else {
TORCH_INTERNAL_ASSERT(false);
}
}
}

at::Tensor castTensorTo(
at::Tensor self,
const torch::jit::IValue& dtype,
const torch::jit::IValue& device) {
at::ScalarType scalar_type =
dtype.isNone() ? self.scalar_type() : dtype.toScalarType();
c10::Device dev = device.isNone() ? self.device() : device.toDevice();
if (scalar_type != self.scalar_type() || dev != self.device()) {
self = self.to(dev, scalar_type);
}
return self;
}

std::vector<int64_t> compute_sizes(const torch::jit::IValue& seq) {
std::vector<int64_t> sizes;
auto seq_recur = seq.toList();
while (true) {
sizes.push_back(seq_recur.size());
if (seq_recur.size() == 0 || !seq_recur.get(0).isList()) {
break;
}
seq_recur = seq_recur.get(0).toList();
}
return sizes;
}

at::Tensor createTensorFromList(const torch::jit::IValue& data, const torch::jit::IValue& dtype, const torch::jit::IValue& device) {
auto elem_type = data.type();
while (auto list_type = elem_type->cast<c10::ListType>()) {
elem_type = list_type->getElementType();
}
auto sizes = compute_sizes(data);
checkListInputType(elem_type, sizes.size() == 1 && sizes[0] == 0);
at::ScalarType initial_scalar_type = c10::scalarTypeFromJitType(elem_type);
if (initial_scalar_type == at::ScalarType::Double) {
initial_scalar_type = at::typeMetaToScalarType(c10::get_default_dtype());
}

auto tensor =
at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type));

if (tensor.numel() != 0) {
recursiveStore(
(char*)tensor.data_ptr(),
sizes,
tensor.strides(),
0,
tensor.element_size(),
data);
}

tensor = castTensorTo(tensor, dtype, device);
auto default_type = at::typeMetaToScalarType(at::get_default_dtype());

if (dtype.isNone() && tensor.scalar_type() != default_type &&
tensor.numel() == 0) {
LOG_WARNING(
"Creating a tensor from an empty "
<< elem_type->repr_str()
<< "list will create a tensor of default floating point type (currently "
<< default_type
<< ") in python but a tensor of type "
<< elem_type->repr_str()
<< " in torchscript.\n"
<< "Pass in a dtype argument to ensure consistent behavior");
}

return tensor;
}

} // namespace evaluators
} // namespace conversion
} // namespace core
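A worked trace of the two core helpers on the mask literal used by the new test (a sketch, assuming a contiguous row-major tensor):

// compute_sizes takes the length at each depth from the first element:
//   [[[0, 0, 1], [1, 1, 0]]]  ->  sizes = {1, 2, 3}
// recursiveStore then writes scalars through the tensor's strides; for a
// contiguous {1, 2, 3} Long tensor, strides = {6, 3, 1} and element_size = 8,
// so element (0, i, j) lands at byte offset (i * 3 + j) * 8 from data_ptr().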
1 change: 1 addition & 0 deletions core/conversion/evaluators/eval_util.h
@@ -8,6 +8,7 @@ namespace conversion {
namespace evaluators {

c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
at::Tensor createTensorFromList(const torch::jit::IValue& data, const torch::jit::IValue& dtype, const torch::jit::IValue& device);

} // namespace evaluators
} // namespace conversion
2 changes: 2 additions & 0 deletions core/util/trt_util.h
@@ -31,6 +31,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
return stream << "Int8";
case nvinfer1::DataType::kINT32:
return stream << "Int32";
case nvinfer1::DataType::kBOOL:
return stream << "Bool";
default:
return stream << "Unknown Data Type";
}
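This case keeps the cast logging in castITensor readable when Bool tensors are involved. A small sketch, assuming code running inside trtorch::core::util where this operator<< is visible:

// std::ostringstream ss;
// ss << nvinfer1::DataType::kBOOL; // ss.str() == "Bool"
// Previously kBOOL fell through to the default "Unknown Data Type" branch.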
43 changes: 43 additions & 0 deletions tests/core/conversion/converters/test_select.cpp
@@ -4,6 +4,8 @@
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "core/lowering/passes/passes.h"

TEST(Converters, ATenSelectIntConvertsCorrectly) {
const auto graph = R"IR(
@@ -398,3 +400,44 @@ TEST(Converters, ATenSplitAndAddConvertsCorrectly) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
}
}

TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%44 : Device = prim::Constant[value="cuda"]()
%8 : bool = prim::Constant[value=0]()
%7 : None = prim::Constant()
%1 : int = prim::Constant[value=0]() # bert.py:5:26
%2 : int = prim::Constant[value=1]() # bert.py:5:32
%33 : int = prim::Constant[value=2]() # bert.py:6:31
%3 : int[] = prim::ListConstruct(%1, %1, %2)
%4 : int[] = prim::ListConstruct(%2, %2, %1)
%5 : int[][] = prim::ListConstruct(%3, %4)
%6 : int[][][] = prim::ListConstruct(%5)
%9 : Tensor = aten::tensor(%6, %1, %7, %8) # bert.py:5:11
%mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11
%mask.2 : Tensor = trt::const(%mask.1)
%34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11
return (%34, %mask.2))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, &*g);

auto in = at::zeros({1, 2, 3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
trtorch::core::lowering::passes::RemoveNOPs(g);
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

std::cout << jit_results[0] << trt_results[0].reshape_as(jit_results[0]) << std::endl;

std::cout << trt_results[1].reshape_as(jit_results[0]) << std::endl;

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
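Traced by hand, the tensors this test compares (masked_fill writes the fill value where the mask is true):

// in   = zeros({1, 2, 3})
// mask = [[[0, 0, 1], [1, 1, 0]]], cast to Bool by the converter
// aten::masked_fill(in, mask, 2) = [[[0., 0., 2.], [2., 2., 0.]]]
// Both the TorchScript interpreter (jit_results) and the TensorRT engine
// (trt_results) should produce this tensor, compared elementwise within 2e-6.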
