Skip to content

Commit

Permalink
Added embedding_bag and fixed unbind int
Browse files Browse the repository at this point in the history
  • Loading branch information
ynimmaga committed Mar 6, 2024
1 parent a3f919a commit 238a3f0
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, options):
"torch.ops.aten._adaptive_avg_pool2d.default": None,
"torch.ops.aten._adaptive_avg_pool3d.default": None,
"torch.ops.aten._convolution.default": None,
"torch.ops.aten._embedding_bag.default": None,
"torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default": None,
"torch.ops.aten._local_scalar_dense.default": None,
"torch.ops.aten._log_softmax.default": None,
Expand Down
20 changes: 16 additions & 4 deletions src/frontends/pytorch/src/op/embedding_bag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_embedding_bag(const NodeContext& context) {
OutputVector translate_embedding_bag_common(const NodeContext& context) {
// aten::embedding_bag(weight, input, offsets=None, scale_grad_by_freq=False, mode_enum=1, sparse=False,
// per_sample_weights=None, include_last_offset=False, padding_idx=None)
num_inputs_check(context, 9, 9);
// we have only EmbeddingBagSum case support, check it before translation
auto mode = context.const_input<int64_t>(4);
PYTORCH_OP_CONVERSION_CHECK(mode == 0, "Only sum mode supported for aten::embedding_bag translation");
Expand All @@ -43,7 +42,9 @@ OutputVector translate_embedding_bag(const NodeContext& context) {
// with offsets case
auto offsets = context.get_input(2);
offsets = context.mark_node(std::make_shared<ov::op::v0::Convert>(offsets, element::i32));
auto include_last_offset = context.const_input<bool>(7);
bool include_last_offset = false;
if (!context.input_is_none(7))
include_last_offset = context.const_input<bool>(7);
PYTORCH_OP_CONVERSION_CHECK(!include_last_offset, "Inclusion last offset is not supported");
// no per_sample_wights
if (context.input_is_none(6)) {
Expand All @@ -63,7 +64,18 @@ OutputVector translate_embedding_bag(const NodeContext& context) {
return {result, zero, zero, zero};
};

OutputVector translate_embedding_bag(const NodeContext& context) {
num_inputs_check(context, 9, 9);
return translate_embedding_bag_common(context);
}

OutputVector translate_embedding_bag_fx(const NodeContext& context) {
num_inputs_check(context, 7, 9);
ov::OutputVector output = translate_embedding_bag_common(context);
return {context.mark_node(make_list_construct(output))};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
14 changes: 10 additions & 4 deletions src/frontends/pytorch/src/op/split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,18 @@ OutputVector translate_chunk_fx(const NodeContext& context) {
}

OutputVector translate_unbind_int_fx(const NodeContext& context) {
num_inputs_check(context, 2, 3);
num_inputs_check(context, 1, 3);
auto input = context.get_input(0);
auto dim = context.get_input(1);
auto dim_val = context.const_input<int>(1);
Output<Node> dim;
int64_t dim_val = 0;
if (context.input_is_none(1)) {
dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
} else {
dim = context.get_input(1);
dim_val = context.const_input<int>(1);
}

auto shape = input.get_shape();

if (dim_val < 0) {
dim_val = static_cast<int>(shape.size()) + dim_val;
}
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ OP_CONVERTER(translate_constant_pad_nd_fx);
OP_CONVERTER(translate_cumsum_fx);
OP_CONVERTER(translate_chunk_fx);
OP_CONVERTER(translate_div_fx);
OP_CONVERTER(translate_embedding_bag_fx);
OP_CONVERTER(translate_expand_fx);
OP_CONVERTER(translate_fake_quantize_per_channel_affine_fx);
OP_CONVERTER(translate_fake_quantize_per_tensor_affine_fx);
Expand Down Expand Up @@ -691,6 +692,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten._adaptive_avg_pool2d.default", op::translate_adaptive_avg_pool2d},
{"aten._adaptive_avg_pool3d.default", op::translate_adaptive_avg_pool3d},
{"aten._convolution.default", op::translate_convolution},
{"aten._embedding_bag.default", op::translate_embedding_bag_fx},
{"aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default",
op::translate_fake_quantize_per_tensor_affine_fx},
{"aten._local_scalar_dense.default", op::skip_node},
Expand Down

0 comments on commit 238a3f0

Please sign in to comment.