Skip to content

Commit

Permalink
Update precision assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
xuchen-intel committed Aug 8, 2024
1 parent ad2fea7 commit 605bdb5
Showing 1 changed file with 6 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const Ex

jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
: jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) {
OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
"Unsupported input type: ", src_prc.get_type_name());
OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
"Unsupported output type: ", dst_prc.get_type_name());
bool is_supported_precision = one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
(src_prc == dst_prc || one_of(dst_prc, ov::element::f32, ov::element::i32));
OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair.");

const auto load = std::dynamic_pointer_cast<snippets::op::Load>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(load != nullptr, "Expects Load expression");
Expand Down Expand Up @@ -103,10 +102,9 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector<size_t> &in, const s

jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
: jit_memory_emitter(h, isa, expr, emitter_in_out_map::vec_to_gpr) {
OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
"Unsupported input type: ", src_prc.get_type_name());
OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
"Unsupported output type: ", dst_prc.get_type_name());
bool is_supported_precision = one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
(src_prc == dst_prc || one_of(src_prc, ov::element::f32, ov::element::i32));
OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair.");

if (ov::is_type<ov::intel_cpu::StoreConvertTruncation>(expr->get_node())) {
store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, arithmetic_mode::truncation));
Expand Down

0 comments on commit 605bdb5

Please sign in to comment.