Commit

Merge branch 'op_support_2024.0' into aot_autograd_changes
suryasidd committed Feb 9, 2024
2 parents e2b86c9 + 7030f32 commit 6de206b
Showing 4 changed files with 19 additions and 56 deletions.
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/op/log.cpp
@@ -33,8 +33,8 @@ OutputVector translate_log(const NodeContext& context) {
 OutputVector translate_log_sigmoid(const NodeContext& context) {
     num_inputs_check(context, 1, 1);
     auto x = context.get_input(0);
-    x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
     auto sigmoid = context.mark_node(std::make_shared<v0::Sigmoid>(x));
+    sigmoid = context.mark_node(std::make_shared<v0::Convert>(sigmoid, element::f32));
     auto log = context.mark_node(std::make_shared<v0::Log>(sigmoid));
     return {log};
 };
45 changes: 12 additions & 33 deletions src/frontends/pytorch/src/op/min_max.cpp
@@ -37,19 +37,19 @@ OutputVector translate_max(const NodeContext& context) {
         align_eltwise_input_types(context, x, y, true);
         return {context.mark_node(std::make_shared<v1::Maximum>(x, y))};
     }
-    // torch.max(input, dim, keepdim), returns values and indicies
+    // torch.max(input, dim, keepdim), returns values and indices
     auto axes_node = context.get_input(1);
     auto axis_const = context.const_input<int64_t>(1);
     auto keepdims = context.const_input<bool>(2);
     auto values = context.mark_node(std::make_shared<v1::ReduceMax>(x, axes_node, keepdims));
     auto k = context.mark_node(std::make_shared<v0::Constant>(element::i32, Shape{}, 1));
     auto topk =
         context.mark_node(std::make_shared<v3::TopK>(x, k, axis_const, v3::TopK::Mode::MAX, v3::TopK::SortType::NONE));
-    auto indicies = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
+    auto indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
     if (!keepdims) {
-        indicies = context.mark_node(std::make_shared<v0::Squeeze>(indicies, axes_node));
+        indices = context.mark_node(std::make_shared<v0::Squeeze>(indices, axes_node));
     }
-    return {values, indicies};
+    return {values, indices};
 };

 OutputVector translate_max_dim(const NodeContext& context) {
@@ -68,36 +68,15 @@ OutputVector translate_max_dim(const NodeContext& context) {
     auto values = context.mark_node(std::make_shared<v1::ReduceMax>(x, axes_node, keepdims));
     auto k = context.mark_node(std::make_shared<v0::Constant>(element::i32, Shape{}, 1));
     auto topk = std::make_shared<v3::TopK>(x, k, axis_const, v3::TopK::Mode::MAX, v3::TopK::SortType::NONE);
-    auto indicies = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
+    auto indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
     if (!keepdims) {
-        indicies = std::make_shared<v0::Squeeze>(indicies, axes_node);
+        indices = std::make_shared<v0::Squeeze>(indices, axes_node);
     }
-    return {values, indicies};
+    return {values, indices};
 };

 OutputVector translate_max_dim_fx(const NodeContext& context) {
-    // torch.max (same for torch.min) actually has two interfaces smashed together:
-    // torch.max(x, dim, keepdim) and torch.max(x, y)
-    num_inputs_check(context, 2, 3);
-    auto x = context.get_input(0);
-    auto axes_node = context.get_input(1);
-    auto axis_const = context.const_input<int64_t>(1);
-
-    bool keepdims = false;
-    if (!context.input_is_none(2)) {
-        keepdims = context.const_input<bool>(2);
-    }
-
-    auto values = context.mark_node(std::make_shared<v1::ReduceMax>(x, axes_node, keepdims));
-    auto k = context.mark_node(std::make_shared<v0::Constant>(element::i32, Shape{}, 1));
-    auto topk = std::make_shared<v3::TopK>(x, k, axis_const, v3::TopK::Mode::MAX, v3::TopK::SortType::NONE);
-    auto indicies = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
-    if (!keepdims) {
-        indicies = std::make_shared<v0::Squeeze>(indicies, axes_node);
-    }
-    ov::OutputVector out_vec;
-    out_vec.push_back(values);
-    out_vec.push_back(indicies);
+    ov::OutputVector out_vec = translate_max_dim(context);
     return {context.mark_node(make_list_construct(out_vec))};
 };
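
Note: translate_max_dim_fx now delegates to translate_max_dim instead of duplicating its body, and the indicies/indices typo is fixed throughout. The shared pattern (translate_min is symmetric, with ReduceMin and TopK::Mode::MIN) computes the values with a reduction and the indices with a k=1 TopK, converted to i64 to match torch.max's index dtype. A hedged standalone sketch of that pattern; the max_with_indices helper is invented for illustration:

#include <memory>

#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/reduce_max.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/topk.hpp"

// Values come from ReduceMax; indices from a k=1 TopK along the same axis,
// widened to i64 and squeezed when keepdim is false.
ov::OutputVector max_with_indices(const ov::Output<ov::Node>& x,
                                  const ov::Output<ov::Node>& axes_node,
                                  int64_t axis,
                                  bool keepdims) {
    auto values = std::make_shared<ov::op::v1::ReduceMax>(x, axes_node, keepdims);
    auto k = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
    auto topk = std::make_shared<ov::op::v3::TopK>(x, k, axis,
                                                   ov::op::v3::TopK::Mode::MAX,
                                                   ov::op::v3::TopK::SortType::NONE);
    ov::Output<ov::Node> indices =
        std::make_shared<ov::op::v0::Convert>(topk->output(1), ov::element::i64);
    if (!keepdims) {
        indices = std::make_shared<ov::op::v0::Squeeze>(indices, axes_node);
    }
    return {values, indices};
}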

@@ -117,19 +96,19 @@ OutputVector translate_min(const NodeContext& context) {
         align_eltwise_input_types(context, x, y, true);
         return {context.mark_node(std::make_shared<v1::Minimum>(x, y))};
     }
-    // torch.min(input, dim, keepdim), returns values and indicies
+    // torch.min(input, dim, keepdim), returns values and indices
     auto axes_node = context.get_input(1);
     auto axis_const = context.const_input<int64_t>(1);
     auto keepdims = context.const_input<bool>(2);
     auto values = context.mark_node(std::make_shared<v1::ReduceMin>(x, axes_node, keepdims));
     auto k = context.mark_node(std::make_shared<v0::Constant>(element::i32, Shape{}, 1));
     auto topk =
         context.mark_node(std::make_shared<v3::TopK>(x, k, axis_const, v3::TopK::Mode::MIN, v3::TopK::SortType::NONE));
-    auto indicies = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
+    auto indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
     if (!keepdims) {
-        indicies = context.mark_node(std::make_shared<v0::Squeeze>(indicies, axes_node));
+        indices = context.mark_node(std::make_shared<v0::Squeeze>(indices, axes_node));
     }
-    return {values, indicies};
+    return {values, indices};
 };

 OutputVector translate_maximum(const NodeContext& context) {
4 changes: 2 additions & 2 deletions src/frontends/pytorch/src/op/split.cpp
@@ -41,14 +41,14 @@ OutputVector translate_unbind_int_fx(const NodeContext& context) {
     auto input = context.get_input(0);
     auto dim = context.get_input(1);
     auto dim_val = context.const_input<int>(1);
-    auto shape = context.get_input(0).get_shape();
+    auto shape = input.get_shape();

     if (dim_val < 0) {
         dim_val = static_cast<int>(shape.size()) + dim_val;
     }

     auto num_splits = static_cast<int>(shape[dim_val]);
-    auto chunk = context.mark_node(std::make_shared<v1::Split>(context.get_input(0), dim, num_splits));
+    auto chunk = context.mark_node(std::make_shared<v1::Split>(input, dim, num_splits));

     return {context.mark_node(make_list_construct(chunk->outputs()))};
 }
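
Note: the split.cpp change only caches context.get_input(0) in the existing input variable instead of fetching it twice. The translation itself needs a static shape, since Split's chunk count must be known at conversion time and a negative dim is normalized against the rank first. A minimal sketch of the same logic outside the frontend (unbind_static is an illustrative name; Output::get_shape() throws on dynamic shapes):

#include <memory>

#include "openvino/op/constant.hpp"
#include "openvino/op/split.hpp"

// Normalize a possibly negative dim against the rank, then split the tensor
// into shape[dim] single-slice chunks, mirroring torch.unbind semantics.
ov::OutputVector unbind_static(const ov::Output<ov::Node>& input, int dim_val) {
    const auto& shape = input.get_shape();  // requires a static shape
    if (dim_val < 0) {
        dim_val = static_cast<int>(shape.size()) + dim_val;
    }
    auto num_splits = static_cast<int>(shape[dim_val]);
    auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {dim_val});
    auto chunk = std::make_shared<ov::op::v1::Split>(input, axis, num_splits);
    return chunk->outputs();
}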
24 changes: 4 additions & 20 deletions src/frontends/pytorch/src/op/var_mean.cpp
@@ -75,35 +75,19 @@ OutputVector translate_var_mean(const NodeContext& context) {
 };

 OutputVector translate_var_mean_fx(const NodeContext& context) {
-    num_inputs_check(context, 1, 4);
+    num_inputs_check(context, 2, 2);
     auto data = context.get_input(0);
-    bool unbiased = false;
-    auto num_elements = numel(context, data);
-    std::shared_ptr<ov::Node> mean, t_mean;
+    std::shared_ptr<ov::Node> mean;
     ov::Output<ov::Node> axes;

     axes = context.get_input(1);
     auto axis_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
     mean = context.mark_node(std::make_shared<v1::ReduceMean>(data, axes, true));
-    t_mean = context.mark_node(std::make_shared<v1::ReduceMean>(data, axes, true));
-    auto reduced_dims = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32));
-    auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
-    reduced_dims = context.mark_node(std::make_shared<v8::Gather>(reduced_dims, axes, zero));
-    num_elements = context.mark_node(std::make_shared<v1::ReduceProd>(reduced_dims, zero, false));

-    auto sub_v = context.mark_node(std::make_shared<v1::Subtract>(data, t_mean));
+    auto sub_v = context.mark_node(std::make_shared<v1::Subtract>(data, mean));
     auto sqr_sub = context.mark_node(std::make_shared<v1::Multiply>(sub_v, sub_v));
     auto var = context.mark_node(std::make_shared<v1::ReduceMean>(sqr_sub, axes, true));
-    // if unbiased=true Bessel’s correction will be used
-    // Correct bias in calculating variance, by dividing it over (N - 1) instead on N
-    if (unbiased) {
-        num_elements = context.mark_node(std::make_shared<v1::ConvertLike>(num_elements, data));
-        auto one = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
-        one = context.mark_node(std::make_shared<v1::ConvertLike>(one, data));
-        auto mul = context.mark_node(std::make_shared<v1::Multiply>(var, num_elements));
-        auto n_minus_one = context.mark_node(std::make_shared<v1::Subtract>(num_elements, one));
-        var = context.mark_node(std::make_shared<v1::Divide>(mul, n_minus_one));
-    }

     ov::OutputVector out_vec;

     out_vec.push_back(var);
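
Note: the removed unbiased branch was dead code, since unbiased was hardcoded to false, so translate_var_mean_fx always produced the biased estimate var = mean((x - mean(x))^2); this change deletes that dead path along with the now-unused num_elements computation. A hedged sketch of the computation that remains, returning the (var, mean) pair that torch.var_mean expects (biased_var_mean is an illustrative name):

#include <memory>
#include <utility>

#include "openvino/op/multiply.hpp"
#include "openvino/op/reduce_mean.hpp"
#include "openvino/op/subtract.hpp"

// var = mean((x - mean(x))^2): the biased (correction = 0) estimator, with
// keep_dims=true so both outputs keep the reduced axes as size-1 dims.
std::pair<ov::Output<ov::Node>, ov::Output<ov::Node>> biased_var_mean(
    const ov::Output<ov::Node>& data,
    const ov::Output<ov::Node>& axes) {
    auto mean = std::make_shared<ov::op::v1::ReduceMean>(data, axes, true);
    auto centered = std::make_shared<ov::op::v1::Subtract>(data, mean);
    auto squared = std::make_shared<ov::op::v1::Multiply>(centered, centered);
    auto var = std::make_shared<ov::op::v1::ReduceMean>(squared, axes, true);
    return {var, mean};
}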
