-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for quantized fx ops for export path #35
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -32,6 +32,27 @@ OutputVector translate_quantize_per_channel(const NodeContext& context) { | |||||||||||||||||||||
return {quantize(context, input, scales, zero_points, axis, dtype, QuantizedPtNodeType::QUANTIZE_PER_CHANNEL)}; | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
OutputVector translate_quantize_per_tensor_fx(const NodeContext& context) { | ||||||||||||||||||||||
num_inputs_check(context, 4, 8); | ||||||||||||||||||||||
const auto input = context.get_input(0); | ||||||||||||||||||||||
const auto scale = context.get_input(1); | ||||||||||||||||||||||
const auto zero_point = context.get_input(2); | ||||||||||||||||||||||
auto low = context.const_input<int64_t>(3); | ||||||||||||||||||||||
auto high = context.const_input<int64_t>(4); | ||||||||||||||||||||||
return {quantize_fx(context, input, scale, zero_point, low, high, element::i8, QuantizedPtNodeType::QUANTIZE_PER_TENSOR)}; | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
OutputVector translate_quantize_per_channel_fx(const NodeContext& context) { | ||||||||||||||||||||||
num_inputs_check(context, 4, 8); | ||||||||||||||||||||||
const auto input = context.get_input(0); | ||||||||||||||||||||||
const auto scales = context.get_input(1); | ||||||||||||||||||||||
const auto zero_points = context.get_input(2); | ||||||||||||||||||||||
const auto axis = context.get_input(3); | ||||||||||||||||||||||
auto low = context.const_input<int64_t>(4); | ||||||||||||||||||||||
auto high = context.const_input<int64_t>(5); | ||||||||||||||||||||||
return {quantize_fx(context, input, scales, zero_points, axis, low, high, element::i8, QuantizedPtNodeType::QUANTIZE_PER_CHANNEL)}; | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [reviewdog-suggester] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
OutputVector translate_fake_quantize_per_tensor_affine_fx(const NodeContext& context) { | ||||||||||||||||||||||
num_inputs_check(context, 6, 6); | ||||||||||||||||||||||
auto out = translate_quantize_per_tensor(context); | ||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -23,29 +23,20 @@ namespace pytorch { | |||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
using namespace ov::op; | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Output<Node> quantize(const NodeContext& context, | ||||||||||||||||||||||||||||||||||
Output<Node> quantize_common(const NodeContext& context, | ||||||||||||||||||||||||||||||||||
const Output<Node>& input, | ||||||||||||||||||||||||||||||||||
const Output<Node>& scale, | ||||||||||||||||||||||||||||||||||
const Output<Node>& zero_point, | ||||||||||||||||||||||||||||||||||
const Output<Node>& axis, | ||||||||||||||||||||||||||||||||||
int64_t out_low_i64, | ||||||||||||||||||||||||||||||||||
int64_t out_high_i64, | ||||||||||||||||||||||||||||||||||
element::Type dtype, | ||||||||||||||||||||||||||||||||||
QuantizedPtNodeType quantization_type) { | ||||||||||||||||||||||||||||||||||
Comment on lines
27
to
34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [reviewdog-suggester] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||
if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) { | ||||||||||||||||||||||||||||||||||
const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32)); | ||||||||||||||||||||||||||||||||||
const auto scale_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32)); | ||||||||||||||||||||||||||||||||||
const auto zero_point_convert = context.mark_node(std::make_shared<v0::Convert>(zero_point, element::f32)); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
int64_t out_low_i64, out_high_i64; | ||||||||||||||||||||||||||||||||||
if (dtype == element::u8) { | ||||||||||||||||||||||||||||||||||
out_low_i64 = (int64_t)std::numeric_limits<unsigned char>::lowest(); | ||||||||||||||||||||||||||||||||||
out_high_i64 = (int64_t)std::numeric_limits<unsigned char>::max(); | ||||||||||||||||||||||||||||||||||
} else if (dtype == element::i8) { | ||||||||||||||||||||||||||||||||||
out_low_i64 = (int64_t)std::numeric_limits<char>::lowest(); | ||||||||||||||||||||||||||||||||||
out_high_i64 = (int64_t)std::numeric_limits<char>::max(); | ||||||||||||||||||||||||||||||||||
} else { // i32 | ||||||||||||||||||||||||||||||||||
out_low_i64 = (int64_t)std::numeric_limits<int>::lowest(); | ||||||||||||||||||||||||||||||||||
out_high_i64 = (int64_t)std::numeric_limits<int>::max(); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
int64_t levels = out_high_i64 - out_low_i64 + 1; | ||||||||||||||||||||||||||||||||||
const auto out_low = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_low_i64})); | ||||||||||||||||||||||||||||||||||
const auto out_high = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_high_i64})); | ||||||||||||||||||||||||||||||||||
|
@@ -75,17 +66,6 @@ Output<Node> quantize(const NodeContext& context, | |||||||||||||||||||||||||||||||||
const auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); | ||||||||||||||||||||||||||||||||||
const auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
int64_t out_low_i64, out_high_i64; | ||||||||||||||||||||||||||||||||||
if (dtype == element::u8) { | ||||||||||||||||||||||||||||||||||
out_low_i64 = (int64_t)std::numeric_limits<unsigned char>::lowest(); | ||||||||||||||||||||||||||||||||||
out_high_i64 = (int64_t)std::numeric_limits<unsigned char>::max(); | ||||||||||||||||||||||||||||||||||
} else if (dtype == element::i8) { | ||||||||||||||||||||||||||||||||||
out_low_i64 = (int64_t)std::numeric_limits<char>::lowest(); | ||||||||||||||||||||||||||||||||||
out_high_i64 = (int64_t)std::numeric_limits<char>::max(); | ||||||||||||||||||||||||||||||||||
} else { // i32 | ||||||||||||||||||||||||||||||||||
out_low_i64 = (int64_t)std::numeric_limits<int>::lowest(); | ||||||||||||||||||||||||||||||||||
out_high_i64 = (int64_t)std::numeric_limits<int>::max(); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
int64_t levels = out_high_i64 - out_low_i64 + 1; | ||||||||||||||||||||||||||||||||||
const auto out_low = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_low_i64})); | ||||||||||||||||||||||||||||||||||
const auto out_high = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_high_i64})); | ||||||||||||||||||||||||||||||||||
|
@@ -120,6 +100,27 @@ Output<Node> quantize(const NodeContext& context, | |||||||||||||||||||||||||||||||||
FRONT_END_OP_CONVERSION_CHECK(false, "Got unknown quantization method in quantize."); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Output<Node> quantize(const NodeContext& context, | ||||||||||||||||||||||||||||||||||
const Output<Node>& input, | ||||||||||||||||||||||||||||||||||
const Output<Node>& scale, | ||||||||||||||||||||||||||||||||||
const Output<Node>& zero_point, | ||||||||||||||||||||||||||||||||||
const Output<Node>& axis, | ||||||||||||||||||||||||||||||||||
element::Type dtype, | ||||||||||||||||||||||||||||||||||
QuantizedPtNodeType quantization_type) { | ||||||||||||||||||||||||||||||||||
int64_t out_low_i64, out_high_i64; | ||||||||||||||||||||||||||||||||||
if (dtype == element::u8) { | ||||||||||||||||||||||||||||||||||
out_low_i64 = (int64_t)std::numeric_limits<unsigned char>::lowest(); | ||||||||||||||||||||||||||||||||||
out_high_i64 = (int64_t)std::numeric_limits<unsigned char>::max(); | ||||||||||||||||||||||||||||||||||
} else if (dtype == element::i8) { | ||||||||||||||||||||||||||||||||||
out_low_i64 = (int64_t)std::numeric_limits<char>::lowest(); | ||||||||||||||||||||||||||||||||||
out_high_i64 = (int64_t)std::numeric_limits<char>::max(); | ||||||||||||||||||||||||||||||||||
} else { // i32 | ||||||||||||||||||||||||||||||||||
out_low_i64 = (int64_t)std::numeric_limits<int>::lowest(); | ||||||||||||||||||||||||||||||||||
out_high_i64 = (int64_t)std::numeric_limits<int>::max(); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
return quantize_common(context, input, scale, zero_point, axis, out_low_i64, out_high_i64, dtype, quantization_type); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [reviewdog-suggester] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Output<Node> quantize(const NodeContext& context, | ||||||||||||||||||||||||||||||||||
const Output<Node>& input, | ||||||||||||||||||||||||||||||||||
const Output<Node>& scale, | ||||||||||||||||||||||||||||||||||
|
@@ -159,6 +160,69 @@ Output<Node> quantize(const NodeContext& context, | |||||||||||||||||||||||||||||||||
FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode"); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Output<Node> quantize_fx(const NodeContext& context, | ||||||||||||||||||||||||||||||||||
const Output<Node>& input, | ||||||||||||||||||||||||||||||||||
const Output<Node>& scale, | ||||||||||||||||||||||||||||||||||
const Output<Node>& zero_point, | ||||||||||||||||||||||||||||||||||
const Output<Node>& axis, | ||||||||||||||||||||||||||||||||||
int64_t out_low_i64, | ||||||||||||||||||||||||||||||||||
int64_t out_high_i64, | ||||||||||||||||||||||||||||||||||
element::Type dtype, | ||||||||||||||||||||||||||||||||||
QuantizedPtNodeType quantization_type) { | ||||||||||||||||||||||||||||||||||
return quantize_common(context, input, scale, zero_point, axis, out_low_i64, out_high_i64, dtype, quantization_type); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [reviewdog-suggester] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Output<Node> quantize_fx(const NodeContext& context, | ||||||||||||||||||||||||||||||||||
const Output<Node>& input, | ||||||||||||||||||||||||||||||||||
const Output<Node>& scale, | ||||||||||||||||||||||||||||||||||
const Output<Node>& zero_point, | ||||||||||||||||||||||||||||||||||
int64_t out_low_i64, | ||||||||||||||||||||||||||||||||||
int64_t out_high_i64, | ||||||||||||||||||||||||||||||||||
element::Type dtype, | ||||||||||||||||||||||||||||||||||
QuantizedPtNodeType quantization_type) { | ||||||||||||||||||||||||||||||||||
return quantize_fx(context, input, scale, zero_point, Output<Node>(), out_low_i64, out_high_i64, dtype, quantization_type); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [reviewdog-suggester] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Output<Node> quantize_fx(const NodeContext& context, | ||||||||||||||||||||||||||||||||||
const Output<Node>& input, | ||||||||||||||||||||||||||||||||||
int64_t out_low_i64, | ||||||||||||||||||||||||||||||||||
int64_t out_high_i64, | ||||||||||||||||||||||||||||||||||
const Output<Node>& quantized_node) { | ||||||||||||||||||||||||||||||||||
if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) { | ||||||||||||||||||||||||||||||||||
return quantize_fx(context, | ||||||||||||||||||||||||||||||||||
input, | ||||||||||||||||||||||||||||||||||
quantized_pt_node->get_scale(), | ||||||||||||||||||||||||||||||||||
quantized_pt_node->get_zero_point(), | ||||||||||||||||||||||||||||||||||
quantized_pt_node->get_axis(), | ||||||||||||||||||||||||||||||||||
out_low_i64, | ||||||||||||||||||||||||||||||||||
out_high_i64, | ||||||||||||||||||||||||||||||||||
quantized_pt_node->get_dtype(), | ||||||||||||||||||||||||||||||||||
quantized_pt_node->get_type()); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode"); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Output<Node> quantize_fx(const NodeContext& context, | ||||||||||||||||||||||||||||||||||
const Output<Node>& input, | ||||||||||||||||||||||||||||||||||
const Output<Node>& scale, | ||||||||||||||||||||||||||||||||||
const Output<Node>& zero_point, | ||||||||||||||||||||||||||||||||||
int64_t out_low_i64, | ||||||||||||||||||||||||||||||||||
int64_t out_high_i64, | ||||||||||||||||||||||||||||||||||
const Output<Node>& quantized_node) { | ||||||||||||||||||||||||||||||||||
if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) { | ||||||||||||||||||||||||||||||||||
return quantize_fx(context, | ||||||||||||||||||||||||||||||||||
input, | ||||||||||||||||||||||||||||||||||
scale, | ||||||||||||||||||||||||||||||||||
zero_point, | ||||||||||||||||||||||||||||||||||
quantized_pt_node->get_axis(), | ||||||||||||||||||||||||||||||||||
out_low_i64, | ||||||||||||||||||||||||||||||||||
out_high_i64, | ||||||||||||||||||||||||||||||||||
quantized_pt_node->get_dtype(), | ||||||||||||||||||||||||||||||||||
quantized_pt_node->get_type()); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode"); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(std::shared_ptr<Node> node) { | ||||||||||||||||||||||||||||||||||
auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node); | ||||||||||||||||||||||||||||||||||
if (!quant_node) { | ||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -144,6 +144,49 @@ Output<Node> quantize(const NodeContext& context, | |||
const Output<Node>& zero_point, | ||||
const Output<Node>& quantized_node); | ||||
|
||||
/** | ||||
* Quantizes input node with the given parameters. Returns a shared pointer to the new QuantizedPtNode. | ||||
*/ | ||||
Output<Node> quantize_fx(const NodeContext& context, | ||||
const Output<Node>& input, | ||||
const Output<Node>& scale, | ||||
const Output<Node>& zero_point, | ||||
int64_t out_low_i64, | ||||
int64_t out_high_i64, | ||||
element::Type dtype, | ||||
QuantizedPtNodeType quantization_type); | ||||
Output<Node> quantize_fx(const NodeContext& context, | ||||
const Output<Node>& input, | ||||
const Output<Node>& scale, | ||||
const Output<Node>& zero_point, | ||||
const Output<Node>& axis, | ||||
int64_t out_low_i64, | ||||
int64_t out_high_i64, | ||||
element::Type dtype, | ||||
QuantizedPtNodeType quantization_type); | ||||
|
||||
/** | ||||
* Quantizes input node like the quantized node. Returns a shared pointer to the new QuantizedPtNode. | ||||
*/ | ||||
Output<Node> quantize_fx(const NodeContext& context, | ||||
Output<Node> input, | ||||
int64_t out_low_i64, | ||||
int64_t out_high_i64, | ||||
Output<Node> quantized_node); | ||||
|
||||
/** | ||||
* Quantizes input node like the quantized node, with new scale and zero_point parameters. Returns a shared pointer to | ||||
* the new QuantizedPtNode. | ||||
*/ | ||||
Output<Node> quantize_fx(const NodeContext& context, | ||||
const Output<Node>& input, | ||||
const Output<Node>& scale, | ||||
const Output<Node>& zero_point, | ||||
int64_t out_low_i64, | ||||
int64_t out_high_i64, | ||||
const Output<Node>& quantized_node); | ||||
|
||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [reviewdog-suggester] reported by reviewdog 🐶
Suggested change
|
||||
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(std::shared_ptr<Node> node); | ||||
|
||||
namespace op { | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[reviewdog-suggester] reported by reviewdog 🐶