[WIP] Pattern matching for uniform_quantize/dequantize ops with convolution custom calls #1

Status: Open. Wants to merge 1 commit into base: akhil/tmp_conv_fusions
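For context, below is a minimal illustrative sketch (not code from this PR; all names and values are assumptions) of the per-element arithmetic behind the uniform_dequantize → convolution → uniform_quantize chain that the rewriter is intended to collapse into a single oneDNN convolution custom call carrying the scales and zero points. Weights are assumed symmetric (zero point 0), matching the runtime code below, which sets no weight zero-point mask.

// Illustrative sketch only: scalar requantization math that a fused
// int8 convolution reproduces via scales and zero points.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example quantization parameters.
  const float src_scale = 0.5f;   // uniform_dequantize scale of the input
  const int32_t src_zp = 3;       // uniform_dequantize zero point of the input
  const float wei_scale = 0.25f;  // common weight scale (mask = 0)
  const float dst_scale = 0.1f;   // uniform_quantize scale of the result
  const int32_t dst_zp = -2;      // uniform_quantize zero point of the result

  // One int8 input element and one int8 weight element.
  const int8_t x_q = 20;
  const int8_t w_q = -7;

  // Dequantize, "convolve" (a single multiply-accumulate here), requantize.
  const float x = (x_q - src_zp) * src_scale;
  const float w = w_q * wei_scale;  // weights assumed symmetric
  const float acc = x * w;          // stands in for the convolution sum
  const int32_t y_q = std::clamp<int32_t>(
      static_cast<int32_t>(std::lround(acc / dst_scale)) + dst_zp, -128, 127);

  std::printf("requantized output: %d\n", y_q);
  return 0;
}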
36 changes: 18 additions & 18 deletions xla/service/cpu/cpu_compiler.cc
@@ -910,24 +910,6 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
? module->config().intra_op_parallelism_threads()
: tsl::port::NumSchedulableCPUs();

#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
// AOT compiled code runs in single thread.
if (!is_aot_compile) {
// Run SimplifyFPConversions pass to simplify the BF16 pattern and make it
// easier to match.
pipeline.AddPass<SimplifyFPConversions>();
pipeline.AddPass<OneDnnConvolutionRewriter>();
pipeline.AddPass<OneDnnMatMulRewriter>(max_parallelism,
compile_options.thread_pool);
// Run SimplifyFPConversions pass again to remove redundant Convert ops
// that may exist as a result of running OneDnnMatMulRewriter pass.
pipeline.AddPass<SimplifyFPConversions>();
}
#endif // INTEL_MKL && ENABLE_ONEDNN_V3

// Add a fusion pass now that layout assignment is done.
pipeline.AddPass<CpuInstructionFusion>();

// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
// Run this to a fixed point.
@@ -951,6 +933,24 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
}();

#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
// AOT compiled code runs in single thread.
if (!is_aot_compile) {
// Run SimplifyFPConversions pass to simplify the BF16 pattern and make it
// easier to match.
pipeline.AddPass<SimplifyFPConversions>();
pipeline.AddPass<OneDnnConvolutionRewriter>();
pipeline.AddPass<OneDnnMatMulRewriter>(max_parallelism,
compile_options.thread_pool);
// Run SimplifyFPConversions pass again to remove redundant Convert ops
// that may exist as a result of running OneDnnMatMulRewriter pass.
pipeline.AddPass<SimplifyFPConversions>();
}
#endif // INTEL_MKL && ENABLE_ONEDNN_V3

// Add a fusion pass now that layout assignment is done.
pipeline.AddPass<CpuInstructionFusion>();

// Outline ops in the entry computation into calls to subcomputations.
if (!is_aot_compile) {
// Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
108 changes: 105 additions & 3 deletions xla/service/cpu/onednn_convolution.cc
@@ -149,6 +149,17 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution(
MemrefInfo ker_minfo(args[arg_indx++]);
MemrefInfo res_minfo(result);

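// The convolution takes the quantized path when the weights are int8; the
// corresponding scale and zero-point arguments (and, when the result is
// int8/uint8, the destination scale and zero point) are consumed from args
// further below.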
memory::data_type res_dt = res_minfo.GetOneDnnDataType();
bool quant_result =
(res_dt == memory::data_type::s8 || res_dt == memory::data_type::u8);
bool quant_operands = ker_minfo.GetOneDnnDataType() == memory::data_type::s8;
memory::data_type inp_dt = inp_minfo.GetOneDnnDataType();
if (quant_operands) {
// Hybrid quantization is not currently supported.
XLA_LIGHTWEIGHT_CHECK(inp_dt == memory::data_type::s8 ||
inp_dt == memory::data_type::u8);
}

// Permute memory descriptors
auto inp_md = inp_minfo.GetOneDnnMemDesc();
auto ker_md = ker_minfo.GetOneDnnMemDesc();
@@ -200,14 +211,94 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution(
<< std::endl;
}
}

XLA_LIGHTWEIGHT_CHECK(num_args == arg_indx);

dnnl::primitive_attr attrs;
if (post_ops.len() > 0) {
attrs.set_post_ops(post_ops);
}

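// Null placeholders: these are rebound to real scale/zero-point data only on
// the quantized path below.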
auto src_scale_mem = memory(nullptr);
auto src_zp_mem = memory(nullptr);
auto wei_scale_mem = memory(nullptr);
auto wei_zp_mem = memory(nullptr);
auto dst_scale_mem = memory(nullptr);
auto dst_zp_mem = memory(nullptr);

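// Storage for converted zero points and scales. The oneDNN memory objects
// created below wrap these buffers directly, so they must stay alive until
// the primitive has executed.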
std::vector<int> src_zp_vec(1);
std::vector<int> dst_zp_vec(1);
std::vector<float> dst_scale_vec(1);

if (quant_operands) {
MemrefInfo src_scale_minfo(args[arg_indx++]);
MemrefInfo src_zp_minfo(args[arg_indx++]);
MemrefInfo wei_scale_minfo(args[arg_indx++]);
MemrefInfo wei_zp_minfo(args[arg_indx++]);

auto src_scale_md = src_scale_minfo.GetOneDnnMemDesc();
auto src_zp_md = src_zp_minfo.GetOneDnnMemDesc();
auto wei_scale_md = wei_scale_minfo.GetOneDnnMemDesc();
int wei_scale_size = wei_scale_md.get_dims()[0];
auto wei_zp_md = wei_zp_minfo.GetOneDnnMemDesc();

// oneDNN only supports common scale/zp for src (no per-channel support).
XLA_LIGHTWEIGHT_CHECK(src_scale_md.get_dims()[0] == 1);
XLA_LIGHTWEIGHT_CHECK(src_zp_md.get_dims()[0] ==
src_scale_md.get_dims()[0]);

src_scale_mem = memory(src_scale_md, cpu_engine, src_scale_minfo.Data());
int* src_zp_data = (int*)src_zp_minfo.Data();

// Negate the zero point to recover its original value: the HLO optimizer
// flips the zp sign when it rewrites the uniform_dequantize pattern.
// TODO(intel-tf): do this based on a flag passed from the rewriter.
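// Illustrative example (values assumed): an original zero point of 5 arrives
// here as -5, and multiplying by -1 restores 5.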
src_zp_vec[0] = src_zp_data[0] * -1;
src_zp_mem = memory(src_zp_md, cpu_engine, src_zp_vec.data());
wei_scale_mem = memory(wei_scale_md, cpu_engine, wei_scale_minfo.Data());
wei_zp_mem = memory(wei_zp_md, cpu_engine, wei_zp_minfo.Data());

if (quant_result) {
MemrefInfo dst_scale_minfo(args[arg_indx++]);
MemrefInfo dst_zp_minfo(args[arg_indx++]);

auto dst_scale_md = dst_scale_minfo.GetOneDnnMemDesc();
auto dst_zp_md = dst_zp_minfo.GetOneDnnMemDesc();

// oneDNN only supports common scale/zp for dst (no per-channel support).
XLA_LIGHTWEIGHT_CHECK(dst_scale_md.get_dims()[0] == 1);
XLA_LIGHTWEIGHT_CHECK(dst_zp_md.get_dims()[0] ==
dst_scale_md.get_dims()[0]);

float* scale_data = (float*)dst_scale_minfo.Data();
// Take the reciprocal of the scale to recover its original value: the HLO
// optimizer inverts it when it rewrites the uniform_quantize pattern.
// TODO(intel-tf): do this based on a flag passed from the rewriter.
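// Illustrative example (values assumed): an original quantization scale of
// 0.25 arrives here as 4.0, and the reciprocal restores 0.25.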
dst_scale_vec[0] = 1.0 / scale_data[0];
if (dst_zp_md.get_data_type() == memory::data_type::f32) {
// oneDNN expects the zero point to be int32, not f32.
dst_zp_vec[0] = static_cast<int>(((float*)dst_zp_minfo.Data())[0]);
} else {
dst_zp_vec[0] = ((int*)dst_zp_minfo.Data())[0];
}
dst_scale_mem = memory(dst_scale_md, cpu_engine, dst_scale_vec.data());
auto dst_zp_md_new = memory::desc(
dst_zp_md.get_dims(), memory::data_type::s32, memory::format_tag::x);
dst_zp_mem = memory(dst_zp_md_new, cpu_engine, dst_zp_vec.data());
}

attrs.set_scales_mask(DNNL_ARG_SRC, 0);
attrs.set_zero_points_mask(DNNL_ARG_SRC, 0);
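// The scales mask is a bitmask over tensor dimensions: 0 means a single
// common scale, 1 means per-output-channel scales (dim 0 of OIHW weights),
// and 3 covers the group and output-channel dims of grouped (GOIHW) weights.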
const int wei_mask = (wei_scale_size == 1) ? 0 : (groups > 1) ? 3 : 1;
attrs.set_scales_mask(DNNL_ARG_WEIGHTS, wei_mask);
if (quant_result) {
attrs.set_scales_mask(DNNL_ARG_DST, 0);
attrs.set_zero_points_mask(DNNL_ARG_DST, 0);
}
}

XLA_LIGHTWEIGHT_CHECK(num_args == arg_indx);

memory::dims strides_dims = strds;
memory::dims padding_dims_l = pad_l;
memory::dims padding_dims_r = pad_r;
@@ -240,6 +331,17 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution(
{DNNL_ARG_WEIGHTS, new_ker_mem},
{DNNL_ARG_BIAS, bias_mem},
{DNNL_ARG_DST, new_res_mem}};
if (quant_operands) {
conv_args.insert(
{{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scale_mem},
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_mem},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scale_mem}});
if (quant_result) {
conv_args.insert(
{{DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale_mem},
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_mem}});
}
}

conv_prim.execute(onednn_stream, conv_args);
