diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp index 985b3d3cc3d580..ae4cda0e13ce8c 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp @@ -41,10 +41,9 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const Ex jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) { - OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported input type: ", src_prc.get_type_name()); - OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported output type: ", dst_prc.get_type_name()); + bool is_supported_precision = one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && + (src_prc == dst_prc || one_of(dst_prc, ov::element::f32, ov::element::i32)); + OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); const auto load = std::dynamic_pointer_cast(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(load != nullptr, "Expects Load expression"); @@ -103,10 +102,9 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector &in, const s jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr, emitter_in_out_map::vec_to_gpr) { - OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported input type: ", src_prc.get_type_name()); - OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8), - "Unsupported output type: ", dst_prc.get_type_name()); + bool is_supported_precision = one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) && + (src_prc == dst_prc || one_of(src_prc, ov::element::f32, ov::element::i32)); + OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair."); if (ov::is_type(expr->get_node())) { store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, arithmetic_mode::truncation));