diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp
index 3700dc52be..cbe40c7aee 100644
--- a/core/conversion/converters/converter_util.cpp
+++ b/core/conversion/converters/converter_util.cpp
@@ -122,6 +122,27 @@ nvinfer1::ILayer* add_elementwise(
   return ele;
 }
 
+nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype) {
+  if (tensor->getType() != dtype) {
+    std::ostringstream tensor_id;
+    tensor_id << reinterpret_cast<const void*>(tensor);
+
+    auto id_layer = ctx->net->addIdentity(*tensor);
+    TRTORCH_CHECK(id_layer, "Unable to create identity layer for ITensor: " << tensor_id.str());
+    auto casted_tensor = id_layer->getOutput(0);
+    casted_tensor->setType(dtype);
+
+    LOG_DEBUG(ctx->logger, "Casting ITensor " << tensor_id.str() << " from " << tensor->getType() << " to " << dtype);
+
+    std::stringstream ss;
+    ss << "[Cast ITensor " << tensor_id.str() << " from " << tensor->getType() << " to " << dtype << "]";
+    id_layer->setName(ss.str().c_str());
+    return casted_tensor;
+  } else {
+    return tensor;
+  }
+}
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
diff --git a/core/conversion/converters/converter_util.h b/core/conversion/converters/converter_util.h
index cd1d876c5d..03273eaeb1 100644
--- a/core/conversion/converters/converter_util.h
+++ b/core/conversion/converters/converter_util.h
@@ -42,6 +42,9 @@ nvinfer1::ILayer* add_elementwise(
     nvinfer1::ITensor* other,
     const std::string& name);
 
+// If an ITensor is not already of type dtype, add an Identity layer to cast it to dtype
+nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype);
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp
index 616c9907e6..9b4e39c2b5 100644
--- a/core/conversion/converters/impl/select.cpp
+++ b/core/conversion/converters/impl/select.cpp
@@ -3,6 +3,7 @@
 #include "NvInfer.h"
 #include "c10/util/intrusive_ptr.h"
 #include "core/conversion/converters/converters.h"
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/tensorcontainer/TensorContainer.h"
 #include "core/util/prelude.h"
 #include "torch/torch.h"
@@ -247,7 +248,29 @@ auto select_registrations TRTORCH_UNUSED =
                     add_split(ctx, n, args, true);
                     LOG_DEBUG("Converted split op into a list of IValues");
                     return true;
-                  }});
+                  }})
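+        // aten::masked_fill writes `value` into self wherever `mask` is true;
+        // TensorRT's ISelectLayer expresses this as select(mask, value, self).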
+        .pattern({"aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto mask = castITensor(ctx, args[1].ITensorOrFreeze(ctx), nvinfer1::DataType::kBOOL);
+                    auto val = args[2].unwrapToScalar().to<float>();
+                    auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val));
+
+                    TRTORCH_CHECK(
+                        util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false),
+                        "Self and mask tensors are not broadcastable");
+
+                    auto new_layer = ctx->net->addSelect(*mask, *val_t, *self);
+                    TRTORCH_CHECK(new_layer, "Unable to create layer for aten::masked_fill");
+
+                    new_layer->setName(util::node_info(n).c_str());
+
+                    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+                    LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+                    return true;
+                  }});
 
 } // namespace
 } // namespace impl
diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index c3c4f5b02b..8e367f9779 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -8,6 +8,7 @@
 #include "torch/torch.h"
 
 #include "core/conversion/evaluators/eval_macros.h"
+#include "core/conversion/evaluators/eval_util.h"
 #include "core/conversion/evaluators/evaluators.h"
 
 namespace trtorch {
@@ -566,6 +567,21 @@ auto aten_registrations TRTORCH_UNUSED =
                         return {};
                       },
                       EvalOptions()})
+        .evaluator({c10::Symbol::fromQualString("aten::tensor"),
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      auto data = args.at(n->input(0)).IValue();
+                      auto dtype = args.at(n->input(1)).IValue();
+                      auto device = args.at(n->input(2)).IValue();
+                      auto tensor = createTensorFromList(*data, *dtype, *device);
+                      LOG_DEBUG(tensor);
+                      if (tensor.dtype() == at::kByte) {
+                        return tensor.to(at::kInt);
+                      }
+                      return tensor;
+                    },
+                    EvalOptions().validSchemas({
+                        "aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)",
+                    })})
         .evaluator({c10::Symbol::fromQualString("aten::arange"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       int input_size = n->inputs().size();
diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp
index 033ebfc298..6d14fcfd69 100644
--- a/core/conversion/evaluators/eval_util.cpp
+++ b/core/conversion/evaluators/eval_util.cpp
@@ -1,6 +1,9 @@
+#include "ATen/InitialTensorOptions.h"
 #include "ATen/core/List.h"
 #include "ATen/core/functional.h"
 #include "ATen/core/ivalue.h"
+#include "ATen/core/jit_type.h"
+#include "c10/util/irange.h"
 #include "core/util/prelude.h"
 
 namespace trtorch {
@@ -91,6 +94,181 @@ c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
   }
 }
 
+void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
+  if (!elem_type->isSubtypeOf(c10::NumberType::get()) && elem_type != c10::BoolType::get()) {
+    std::stringstream error;
+    error << "Input must be of ints, floats, or bools, "
+          << "got " << elem_type->repr_str();
+    // special case empty list torch.tensor([])
+    if (elem_type->isSubtypeOf(c10::TensorType::get())) {
+      if (empty_list) {
+        error << "\nEmpty lists default to List[Tensor]. Add a variable "
+                 "annotation to the assignment to create an empty list "
+                 "of another type (torch.jit.annotate(List[T, []]) where T "
+                 "is the type of elements in the list for Python 2)";
+      }
+    }
+    TRTORCH_THROW_ERROR(error.str());
+  }
+}
+
+void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
+  if (seq_size != n) {
+    TRTORCH_THROW_ERROR("Expected sequence of length " << n << " at dim " << dim << " (got " << seq_size << ")");
+  }
+}
+
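+// Copy one innermost row of scalars out of a nested IValue list into the
+// tensor's raw buffer, stepping by the stride of `dim` between elements.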
+template <typename DTYPE>
+void storeLastDimension(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int elementSize,
+    at::ArrayRef<torch::jit::IValue> obj) {
+  auto n = sizes[dim];
+  auto seq_size = obj.size();
+  checkSequenceSize(n, dim, seq_size);
+  for (const auto i : c10::irange(n)) {
+    *(DTYPE*)data = obj[i].to<DTYPE>();
+    data += strides[dim] * elementSize;
+  }
+}
+
+void storeLastDimensionFloat(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int elementSize,
+    at::ArrayRef<torch::jit::IValue> obj) {
+  auto n = sizes[dim];
+  auto seq_size = obj.size();
+  checkSequenceSize(n, dim, seq_size);
+  for (int64_t i = 0; i < n; i++) {
+    *(float*)data = static_cast<float>(obj[i].to<double>());
+    data += strides[dim] * elementSize;
+  }
+}
+
+void storeLastDimensionHalf(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int elementSize,
+    at::ArrayRef<torch::jit::IValue> obj) {
+  auto n = sizes[dim];
+  auto seq_size = obj.size();
+  checkSequenceSize(n, dim, seq_size);
+  for (int64_t i = 0; i < n; i++) {
+    *(at::Half*)data = at::convert<at::Half, double>(obj[i].to<double>());
+    data += strides[dim] * elementSize;
+  }
+}
+
+void recursiveStore(
+    char* data,
+    const std::vector<int64_t>& sizes,
+    const c10::ArrayRef<int64_t>& strides,
+    int64_t dim,
+    int tenElementSize,
+    const torch::jit::IValue& obj) {
+  auto ndim = sizes.size();
+  auto n = sizes[dim];
+  auto seq = obj.toListRef();
+  checkSequenceSize(n, dim, seq.size());
+  if (dim + 1 < static_cast<int64_t>(ndim)) {
+    for (const auto i : c10::irange(n)) {
+      recursiveStore(data, sizes, strides, dim + 1, tenElementSize, seq[i]);
+      data += strides[dim] * tenElementSize;
+    }
+  } else {
+    if (obj.isIntList()) {
+      storeLastDimension<int64_t>(data, sizes, strides, dim, tenElementSize, seq);
+    } else if (obj.isBoolList()) {
+      storeLastDimension<bool>(data, sizes, strides, dim, tenElementSize, seq);
+    } else if (obj.isDoubleList()) {
+      if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Double))) {
+        storeLastDimension<double>(data, sizes, strides, dim, tenElementSize, seq);
+      } else if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Float))) {
+        storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, seq);
+      } else if (tenElementSize == static_cast<int>(elementSize(at::ScalarType::Half))) {
+        storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, seq);
+      } else {
+        TORCH_INTERNAL_ASSERT(false);
+      }
+    } else {
+      TORCH_INTERNAL_ASSERT(false);
+    }
+  }
+}
+
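+// Move the assembled tensor to the dtype/device requested by aten::tensor;
+// None for either argument keeps the tensor's current dtype/device.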
+at::Tensor castTensorTo(at::Tensor self, const torch::jit::IValue& dtype, const torch::jit::IValue& device) {
+  at::ScalarType scalar_type = dtype.isNone() ? self.scalar_type() : dtype.toScalarType();
+  c10::Device dev = device.isNone() ? self.device() : device.toDevice();
+  if (scalar_type != self.scalar_type() || dev != self.device()) {
+    self = self.to(dev, scalar_type);
+  }
+  return self;
+}
+
+std::vector<int64_t> compute_sizes(const torch::jit::IValue& seq) {
+  std::vector<int64_t> sizes;
+  auto seq_recur = seq.toList();
+  while (true) {
+    sizes.push_back(seq_recur.size());
+    if (seq_recur.size() == 0 || !seq_recur.get(0).isList()) {
+      break;
+    }
+    seq_recur = seq_recur.get(0).toList();
+  }
+  return sizes;
+}
+
+at::Tensor createTensorFromList(
+    const torch::jit::IValue& data,
+    const torch::jit::IValue& dtype,
+    const torch::jit::IValue& device) {
+  auto elem_type = data.type();
+  while (auto list_type = elem_type->cast<c10::ListType>()) {
+    elem_type = list_type->getElementType();
+  }
+  auto sizes = compute_sizes(data);
+  checkListInputType(elem_type, sizes.size() == 1 && sizes[0] == 0);
+  at::ScalarType initial_scalar_type = c10::scalarTypeFromJitType(elem_type);
+  if (initial_scalar_type == at::ScalarType::Double) {
+    initial_scalar_type = at::typeMetaToScalarType(c10::get_default_dtype());
+  }
+
+  auto tensor = at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type));
+
+  if (tensor.numel() != 0) {
+    recursiveStore((char*)tensor.data_ptr(), sizes, tensor.strides(), 0, tensor.element_size(), data);
+  }
+
+  tensor = castTensorTo(tensor, dtype, device);
+  auto default_type = at::typeMetaToScalarType(at::get_default_dtype());
+
+  if (dtype.isNone() && tensor.scalar_type() != default_type && tensor.numel() == 0) {
+    LOG_WARNING(
+        "Creating a tensor from an empty "
+        << elem_type->repr_str()
+        << " list will create a tensor of default floating point type (currently "
+        << default_type
+        << ") in python but a tensor of type "
+        << elem_type->repr_str()
+        << " in torchscript.\n"
+        << "Pass in a dtype argument to ensure consistent behavior");
+  }
+
+  return tensor;
+}
+
 } // namespace evaluators
 } // namespace conversion
 } // namespace core
diff --git a/core/conversion/evaluators/eval_util.h b/core/conversion/evaluators/eval_util.h
index 1e31ddfe46..0a871f4cfa 100644
--- a/core/conversion/evaluators/eval_util.h
+++ b/core/conversion/evaluators/eval_util.h
@@ -8,6 +8,7 @@ namespace conversion {
 namespace evaluators {
 
 c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
+at::Tensor createTensorFromList(const torch::jit::IValue& data, const torch::jit::IValue& dtype, const torch::jit::IValue& device);
 
 } // namespace evaluators
 } // namespace conversion
diff --git a/core/util/trt_util.h b/core/util/trt_util.h
index 250c30d04d..8a8b399c06 100644
--- a/core/util/trt_util.h
+++ b/core/util/trt_util.h
@@ -31,6 +31,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType& dtype) {
       return stream << "Int8";
     case nvinfer1::DataType::kINT32:
       return stream << "Int32";
+    case nvinfer1::DataType::kBOOL:
+      return stream << "Bool";
     default:
       return stream << "Unknown Data Type";
   }
diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp
index 3b3967b236..52afd7ca09 100644
--- a/tests/core/conversion/converters/test_select.cpp
+++ b/tests/core/conversion/converters/test_select.cpp
@@ -4,6 +4,8 @@
 #include "gtest/gtest.h"
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
+#include "core/lowering/passes/passes.h"
+
 
 TEST(Converters, ATenSelectIntConvertsCorrectly) {
   const auto graph = R"IR(
@@ -398,3 +400,40 @@ TEST(Converters, ATenSplitAndAddConvertsCorrectly) {
     ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
   }
 }
+
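+// End-to-end check for the aten::tensor evaluator and aten::masked_fill converter:
+// the mask is built by the evaluator, frozen via trt::const, and cast to kBOOL.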
+TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %44 : Device = prim::Constant[value="cuda"]()
+      %8 : bool = prim::Constant[value=0]()
+      %7 : None = prim::Constant()
+      %1 : int = prim::Constant[value=0]() # bert.py:5:26
+      %2 : int = prim::Constant[value=1]() # bert.py:5:32
+      %33 : int = prim::Constant[value=2]() # bert.py:6:31
+      %3 : int[] = prim::ListConstruct(%1, %1, %2)
+      %4 : int[] = prim::ListConstruct(%2, %2, %1)
+      %5 : int[][] = prim::ListConstruct(%3, %4)
+      %6 : int[][][] = prim::ListConstruct(%5)
+      %9 : Tensor = aten::tensor(%6, %1, %7, %8) # bert.py:5:11
+      %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11
+      %mask.2 : Tensor = trt::const(%mask.1)
+      %34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11
+      return (%34, %mask.2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::zeros({1, 2, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  trtorch::core::lowering::passes::RemoveNOPs(g);
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
\ No newline at end of file