Skip to content

Commit

Permalink
Merge pull request #35 from ynimmaga/export_quantization
Browse files Browse the repository at this point in the history
Added support for quantized fx ops for export path
  • Loading branch information
ynimmaga authored Mar 29, 2024
2 parents 2e948f2 + 1424ca4 commit 249ff70
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, options):
"torch.ops.aten.argmax.default": None,
"torch.ops.aten.argmin.default": None,
"torch.ops.aten.as_strided.default": None,
"torch.ops.aten.as_strided_.default": None,
"torch.ops.aten.asin.default": None,
"torch.ops.aten.asinh.default": None,
"torch.ops.aten.asinh.default": None,
Expand Down Expand Up @@ -246,6 +247,11 @@ def __init__(self, options):
"torch.ops.aten.zeros_like.default": None,
"torch.ops.torchvision.deform_conv2d.default": None,
"torch.ops.torchvision.roi_align.default": None,
"torch.ops.quantized_decomposed.quantize_per_tensor.default": None,
"torch.ops.quantized_decomposed.quantize_per_channel.default": None,
"torch.ops.quantized_decomposed.dequantize_per_tensor.default": None,
"torch.ops.quantized_decomposed.dequantize_per_channel.default": None

}

for op in _get_disabled_ops(options):
Expand Down
21 changes: 21 additions & 0 deletions src/frontends/pytorch/src/op/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ OutputVector translate_quantize_per_channel(const NodeContext& context) {
return {quantize(context, input, scales, zero_points, axis, dtype, QuantizedPtNodeType::QUANTIZE_PER_CHANNEL)};
}

OutputVector translate_quantize_per_tensor_fx(const NodeContext& context) {
num_inputs_check(context, 4, 8);
const auto input = context.get_input(0);
const auto scale = context.get_input(1);
const auto zero_point = context.get_input(2);
auto low = context.const_input<int64_t>(3);
auto high = context.const_input<int64_t>(4);
return {quantize_fx(context, input, scale, zero_point, low, high, element::i8, QuantizedPtNodeType::QUANTIZE_PER_TENSOR)};
}

OutputVector translate_quantize_per_channel_fx(const NodeContext& context) {
num_inputs_check(context, 4, 8);
const auto input = context.get_input(0);
const auto scales = context.get_input(1);
const auto zero_points = context.get_input(2);
const auto axis = context.get_input(3);
auto low = context.const_input<int64_t>(4);
auto high = context.const_input<int64_t>(5);
return {quantize_fx(context, input, scales, zero_points, axis, low, high, element::i8, QuantizedPtNodeType::QUANTIZE_PER_CHANNEL)};
}

OutputVector translate_fake_quantize_per_tensor_affine_fx(const NodeContext& context) {
num_inputs_check(context, 6, 6);
auto out = translate_quantize_per_tensor(context);
Expand Down
7 changes: 7 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ OP_CONVERTER(translate_std_fx);
OP_CONVERTER(translate_topk_fx);
OP_CONVERTER(translate_to_fx);
OP_CONVERTER(translate_transpose_fx);
OP_CONVERTER(translate_quantize_per_channel_fx);
OP_CONVERTER(translate_quantize_per_tensor_fx);
OP_CONVERTER(translate_var_fx);
OP_CONVERTER(translate_var_mean_fx);
OP_CONVERTER(translate_unbind_int_fx);
Expand Down Expand Up @@ -760,6 +762,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.argmax.default", op::translate_argmax},
{"aten.argmin.default", op::translate_argmin},
{"aten.as_strided.default", op::translate_as_strided},
{"aten.as_strided_.default", op::translate_as_strided},
{"aten.asin.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asin>},
{"aten.asinh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asinh>},
{"aten.atan.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atan>},
Expand Down Expand Up @@ -953,6 +956,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"get_attr", op::translate_constant},
{"torchvision.deform_conv2d.default", op::translate_deform_conv},
{"torchvision.roi_align.default", op::translate_roi_align},
{"quantized_decomposed.quantize_per_tensor.default", op::translate_quantize_per_tensor_fx},
{"quantized_decomposed.quantize_per_channel.default", op::translate_quantize_per_channel_fx},
{"quantized_decomposed.dequantize_per_tensor.default", op::skip_node},
{"quantized_decomposed.dequantize_per_channel.default", op::skip_node},
};
};

Expand Down
110 changes: 87 additions & 23 deletions src/frontends/pytorch/src/utils_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,20 @@ namespace pytorch {

using namespace ov::op;

Output<Node> quantize(const NodeContext& context,
Output<Node> quantize_common(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const Output<Node>& axis,
int64_t out_low_i64,
int64_t out_high_i64,
element::Type dtype,
QuantizedPtNodeType quantization_type) {
if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
const auto scale_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32));
const auto zero_point_convert = context.mark_node(std::make_shared<v0::Convert>(zero_point, element::f32));

int64_t out_low_i64, out_high_i64;
if (dtype == element::u8) {
out_low_i64 = (int64_t)std::numeric_limits<unsigned char>::lowest();
out_high_i64 = (int64_t)std::numeric_limits<unsigned char>::max();
} else if (dtype == element::i8) {
out_low_i64 = (int64_t)std::numeric_limits<char>::lowest();
out_high_i64 = (int64_t)std::numeric_limits<char>::max();
} else { // i32
out_low_i64 = (int64_t)std::numeric_limits<int>::lowest();
out_high_i64 = (int64_t)std::numeric_limits<int>::max();
}
int64_t levels = out_high_i64 - out_low_i64 + 1;
const auto out_low = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_low_i64}));
const auto out_high = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_high_i64}));
Expand Down Expand Up @@ -75,17 +66,6 @@ Output<Node> quantize(const NodeContext& context,
const auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
const auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));

int64_t out_low_i64, out_high_i64;
if (dtype == element::u8) {
out_low_i64 = (int64_t)std::numeric_limits<unsigned char>::lowest();
out_high_i64 = (int64_t)std::numeric_limits<unsigned char>::max();
} else if (dtype == element::i8) {
out_low_i64 = (int64_t)std::numeric_limits<char>::lowest();
out_high_i64 = (int64_t)std::numeric_limits<char>::max();
} else { // i32
out_low_i64 = (int64_t)std::numeric_limits<int>::lowest();
out_high_i64 = (int64_t)std::numeric_limits<int>::max();
}
int64_t levels = out_high_i64 - out_low_i64 + 1;
const auto out_low = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_low_i64}));
const auto out_high = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_high_i64}));
Expand Down Expand Up @@ -120,6 +100,27 @@ Output<Node> quantize(const NodeContext& context,
FRONT_END_OP_CONVERSION_CHECK(false, "Got unknown quantization method in quantize.");
}

Output<Node> quantize(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const Output<Node>& axis,
element::Type dtype,
QuantizedPtNodeType quantization_type) {
int64_t out_low_i64, out_high_i64;
if (dtype == element::u8) {
out_low_i64 = (int64_t)std::numeric_limits<unsigned char>::lowest();
out_high_i64 = (int64_t)std::numeric_limits<unsigned char>::max();
} else if (dtype == element::i8) {
out_low_i64 = (int64_t)std::numeric_limits<char>::lowest();
out_high_i64 = (int64_t)std::numeric_limits<char>::max();
} else { // i32
out_low_i64 = (int64_t)std::numeric_limits<int>::lowest();
out_high_i64 = (int64_t)std::numeric_limits<int>::max();
}
return quantize_common(context, input, scale, zero_point, axis, out_low_i64, out_high_i64, dtype, quantization_type);
}

Output<Node> quantize(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
Expand Down Expand Up @@ -159,6 +160,69 @@ Output<Node> quantize(const NodeContext& context,
FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
}

Output<Node> quantize_fx(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const Output<Node>& axis,
int64_t out_low_i64,
int64_t out_high_i64,
element::Type dtype,
QuantizedPtNodeType quantization_type) {
return quantize_common(context, input, scale, zero_point, axis, out_low_i64, out_high_i64, dtype, quantization_type);
}

Output<Node> quantize_fx(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
int64_t out_low_i64,
int64_t out_high_i64,
element::Type dtype,
QuantizedPtNodeType quantization_type) {
return quantize_fx(context, input, scale, zero_point, Output<Node>(), out_low_i64, out_high_i64, dtype, quantization_type);
}

Output<Node> quantize_fx(const NodeContext& context,
const Output<Node>& input,
int64_t out_low_i64,
int64_t out_high_i64,
const Output<Node>& quantized_node) {
if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
return quantize_fx(context,
input,
quantized_pt_node->get_scale(),
quantized_pt_node->get_zero_point(),
quantized_pt_node->get_axis(),
out_low_i64,
out_high_i64,
quantized_pt_node->get_dtype(),
quantized_pt_node->get_type());
}
FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
}

Output<Node> quantize_fx(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
int64_t out_low_i64,
int64_t out_high_i64,
const Output<Node>& quantized_node) {
if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
return quantize_fx(context,
input,
scale,
zero_point,
quantized_pt_node->get_axis(),
out_low_i64,
out_high_i64,
quantized_pt_node->get_dtype(),
quantized_pt_node->get_type());
}
FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
}

std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(std::shared_ptr<Node> node) {
auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node);
if (!quant_node) {
Expand Down
43 changes: 43 additions & 0 deletions src/frontends/pytorch/src/utils_quantize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,49 @@ Output<Node> quantize(const NodeContext& context,
const Output<Node>& zero_point,
const Output<Node>& quantized_node);

/**
* Quantizes input node with the given parameters. Returns a shared pointer to the new QuantizedPtNode.
*/
Output<Node> quantize_fx(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
int64_t out_low_i64,
int64_t out_high_i64,
element::Type dtype,
QuantizedPtNodeType quantization_type);
Output<Node> quantize_fx(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const Output<Node>& axis,
int64_t out_low_i64,
int64_t out_high_i64,
element::Type dtype,
QuantizedPtNodeType quantization_type);

/**
* Quantizes input node like the quantized node. Returns a shared pointer to the new QuantizedPtNode.
*/
Output<Node> quantize_fx(const NodeContext& context,
Output<Node> input,
int64_t out_low_i64,
int64_t out_high_i64,
Output<Node> quantized_node);

/**
* Quantizes input node like the quantized node, with new scale and zero_point parameters. Returns a shared pointer to
* the new QuantizedPtNode.
*/
Output<Node> quantize_fx(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
int64_t out_low_i64,
int64_t out_high_i64,
const Output<Node>& quantized_node);


std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(std::shared_ptr<Node> node);

namespace op {
Expand Down

0 comments on commit 249ff70

Please sign in to comment.