Skip to content

Commit

Permalink
feat(core//conversion): Implement converter for torch unbind
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Apr 19, 2022
1 parent b8b8fce commit 268a49b
Show file tree
Hide file tree
Showing 3 changed files with 600 additions and 564 deletions.
50 changes: 31 additions & 19 deletions core/conversion/converters/impl/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,32 @@ namespace converters {
namespace impl {
namespace {

bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list) {
bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind) {
auto in = args[0].ITensor();
auto axis = args[2].unwrapToInt();
auto inDimSize = in->getDimensions().d[axis];
auto numOutputs = 1, numRemainder = 0;
auto numOutputs = 1, numRemainder = 0, axis = 0;
std::vector<int64_t> sizes;

if (split_list) {
sizes = args[1].unwrapToIntList().vec();
numOutputs = sizes.size();
if (unbind) {
axis = args[1].unwrapToInt();
numOutputs = in->getDimensions().d[axis];
sizes.insert(sizes.end(), numOutputs, 1);
} else {
auto split_size = args[1].unwrapToInt();
numOutputs = inDimSize / split_size;
numRemainder = inDimSize % split_size;
for (int64_t i = 0; i < numOutputs; i++) {
sizes.push_back(split_size);
}
if (numRemainder) {
numOutputs += 1;
sizes.push_back(numRemainder);
axis = args[2].unwrapToInt();
auto inDimSize = in->getDimensions().d[axis];
if (split_list) {
sizes = args[1].unwrapToIntList().vec();
numOutputs = sizes.size();
} else {
auto split_size = args[1].unwrapToInt();
numOutputs = inDimSize / split_size;
numRemainder = inDimSize % split_size;
for (int64_t i = 0; i < numOutputs; i++) {
sizes.push_back(split_size);
}
if (numRemainder) {
numOutputs += 1;
sizes.push_back(numRemainder);
}
}
}

Expand Down Expand Up @@ -340,19 +346,25 @@ auto select_registrations TORCHTRT_UNUSED =
}})
.pattern({"aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
add_split(ctx, n, args, true);
add_split(ctx, n, args, true, false);
LOG_DEBUG("Converted split op into a list of IValues");
return true;
}})
.pattern({"aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
add_split(ctx, n, args, false);
add_split(ctx, n, args, false, false);
LOG_DEBUG("Converted split op into a list of IValues");
return true;
}})
.pattern({"aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
add_split(ctx, n, args, true);
add_split(ctx, n, args, true, false);
LOG_DEBUG("Converted split op into a list of IValues");
return true;
}})
.pattern({"aten::unbind.int(Tensor(a -> *) self, int dim=0) -> (Tensor[])",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
add_split(ctx, n, args, false, true);
LOG_DEBUG("Converted split op into a list of IValues");
return true;
}})
Expand Down
5 changes: 1 addition & 4 deletions core/lowering/register_trt_placeholder_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
RegisterOperators trt_placeholder_ops_reg({
/// Op marks a Tensor to be conveted from an Torch Tensor
/// to a TRT constant Tensor
Operator(
"trt::const(Tensor val) -> Tensor",
[](Stack& stack) { /*noop*/ },
aliasAnalysisFromSchema()),
Operator("trt::const(Tensor val) -> Tensor", [](Stack& stack) { /*noop*/ }, aliasAnalysisFromSchema()),
});

} // namespace jit
Expand Down
Loading

0 comments on commit 268a49b

Please sign in to comment.