diff --git a/src/cpu/gemm_convolution_utils.cpp b/src/cpu/gemm_convolution_utils.cpp index 7462cad00b8..86943108a98 100644 --- a/src/cpu/gemm_convolution_utils.cpp +++ b/src/cpu/gemm_convolution_utils.cpp @@ -769,13 +769,14 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp, jcp.outer_threading = false; - bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8) - && weights_d.data_type() == s8; - const bool is_bwd_d = jcp.prop_kind == backward_data; const bool is_bwd_w = jcp.prop_kind == backward_weights; const bool is_fwd = !is_bwd_d && !is_bwd_w; + bool is_int8_conv = (is_fwd ? utils::one_of(src_d.data_type(), s8, u8) + : utils::one_of(dst_d.data_type(), s8, u8)) + && weights_d.data_type() == s8; + bool is_bf16_conv = false || (is_fwd && utils::everyone_is(