Skip to content

Commit

Permalink
Call convert_emitter in load/store_emitter
Browse files Browse the repository at this point in the history
  • Loading branch information
xuchen-intel committed Aug 20, 2024
1 parent 3c2b040 commit 52d8898
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 322 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include "jit_conversion_emitters.hpp"
#include "emitters/utils.hpp"
#include "utils.hpp"

using namespace dnnl::impl::cpu::aarch64;
using namespace Xbyak_aarch64;
Expand All @@ -13,6 +12,126 @@ namespace ov {
namespace intel_cpu {
namespace aarch64 {

// In aarch64, conversion between f16 and i16/u16 can be done with a single instruction. The supported
// conversion precisions are f32, i32, f16, i8 (byte), u8 (byte). If we introduce an intermediate
// precision i16/u16 (dbyte) into the following graph, then the conversion between each pair of
// neighbors in this graph can be done with a single instruction.
// f16 - f32 - i32 - dbyte - byte
// | |
// - - - - - - - - - - -
// Emits a single FCVTL that reads the h4 (f16) view of the source vector register and
// writes the s4 (f32) view of the destination vector register.
//   h        - code generator that receives the instruction
//   in_idxs  - in_idxs[0] is the index of the source vector register
//   out_idxs - out_idxs[0] is the index of the destination vector register
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
static void cvt_f16_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
    using VecReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
    VecReg input(in_idxs.front());
    VecReg output(out_idxs.front());
    h->fcvtl(output.s4, input.h4);
}

// Emits a single FCVTN that reads the s4 (f32) view of the source vector register and
// writes the h4 (f16) view of the destination vector register.
//   h        - code generator that receives the instruction
//   in_idxs  - in_idxs[0] is the index of the source vector register
//   out_idxs - out_idxs[0] is the index of the destination vector register
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
static void cvt_f32_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
    using VecReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
    VecReg input(in_idxs.front());
    VecReg output(out_idxs.front());
    h->fcvtn(output.h4, input.s4);
}

// Emits a single FCVTZS (float -> signed integer, rounding toward zero) on the
// 32-bit lane view (.s) of the source and destination vector registers.
//   h        - code generator that receives the instruction
//   in_idxs  - in_idxs[0] is the index of the source vector register
//   out_idxs - out_idxs[0] is the index of the destination vector register
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
static void cvt_f32_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
    using VecReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
    VecReg input(in_idxs.front());
    VecReg output(out_idxs.front());
    h->fcvtzs(output.s, input.s);
}

// Emits a single SCVTF (signed integer -> float) on the 32-bit lane view (.s)
// of the source and destination vector registers.
//   h        - code generator that receives the instruction
//   in_idxs  - in_idxs[0] is the index of the source vector register
//   out_idxs - out_idxs[0] is the index of the destination vector register
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
static void cvt_i32_to_f32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
    using VecReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
    VecReg input(in_idxs.front());
    VecReg output(out_idxs.front());
    h->scvtf(output.s, input.s);
}

// Narrows the s4 (32-bit) view of the source vector register into the h4 (16-bit)
// view of the destination vector register with one instruction:
//   - not saturated         -> XTN   (plain truncation; is_signed is ignored)
//   - saturated && signed   -> SQXTN
//   - saturated && unsigned -> UQXTN
//   h            - code generator that receives the instruction
//   in_idxs      - in_idxs[0] is the index of the source vector register
//   out_idxs     - out_idxs[0] is the index of the destination vector register
//   is_signed    - selects the signed or unsigned saturating narrow
//   is_saturated - selects saturating narrow over plain truncation
// NOTE(review): UQXTN interprets the source lanes as unsigned; confirm that callers never
// reach the unsigned saturating path with negative i32 values.
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
static void cvt_i32_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
                             bool is_signed, bool is_saturated) {
    using VecReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
    VecReg input(in_idxs.front());
    VecReg output(out_idxs.front());

    if (!is_saturated) {
        h->xtn(output.h4, input.s4);
        return;
    }

    if (is_signed) {
        h->sqxtn(output.h4, input.s4);
        return;
    }
    h->uqxtn(output.h4, input.s4);
}

// Widens the h4 (16-bit) view of the source vector register into the s4 (32-bit)
// view of the destination vector register: SXTL when is_signed, UXTL otherwise.
//   h         - code generator that receives the instruction
//   in_idxs   - in_idxs[0] is the index of the source vector register
//   out_idxs  - out_idxs[0] is the index of the destination vector register
//   is_signed - true for sign extension, false for zero extension
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
static void cvt_dbyte_to_i32(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
                             bool is_signed) {
    using VecReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
    VecReg input(in_idxs.front());
    VecReg output(out_idxs.front());

    if (!is_signed) {
        h->uxtl(output.s4, input.h4);
        return;
    }
    h->sxtl(output.s4, input.h4);
}

// Emits a single FCVTZS (float -> signed integer, rounding toward zero) on the
// 16-bit lane view (.h) of the source and destination vector registers.
//   h        - code generator that receives the instruction
//   in_idxs  - in_idxs[0] is the index of the source vector register
//   out_idxs - out_idxs[0] is the index of the destination vector register
// NOTE(review): unlike the other dbyte helpers there is no is_signed switch here, so the
// result is always a signed 16-bit integer; confirm this is intended for u16 consumers.
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
static void cvt_f16_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) {
    using VecReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
    VecReg input(in_idxs.front());
    VecReg output(out_idxs.front());
    h->fcvtzs(output.h, input.h);
}

// Converts 16-bit integer lanes (.h view) of the source vector register to f16 lanes
// of the destination vector register: SCVTF when is_signed, UCVTF otherwise.
//   h         - code generator that receives the instruction
//   in_idxs   - in_idxs[0] is the index of the source vector register
//   out_idxs  - out_idxs[0] is the index of the destination vector register
//   is_signed - true to treat the integer lanes as signed, false as unsigned
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
static void cvt_dbyte_to_f16(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
                             bool is_signed) {
    using VecReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
    VecReg input(in_idxs.front());
    VecReg output(out_idxs.front());

    if (!is_signed) {
        h->ucvtf(output.h, input.h);
        return;
    }
    h->scvtf(output.h, input.h);
}

// Narrows the h8 (16-bit) view of the source vector register into the b8 (8-bit)
// view of the destination vector register with one instruction:
//   - not saturated         -> XTN   (plain truncation; is_signed is ignored)
//   - saturated && signed   -> SQXTN
//   - saturated && unsigned -> UQXTN
//   h            - code generator that receives the instruction
//   in_idxs      - in_idxs[0] is the index of the source vector register
//   out_idxs     - out_idxs[0] is the index of the destination vector register
//   is_signed    - selects the signed or unsigned saturating narrow
//   is_saturated - selects saturating narrow over plain truncation
// NOTE(review): UQXTN interprets the source lanes as unsigned; confirm that callers never
// reach the unsigned saturating path with negative 16-bit values.
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
static void cvt_dbyte_to_byte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
                              bool is_signed, bool is_saturated) {
    using VecReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
    VecReg input(in_idxs.front());
    VecReg output(out_idxs.front());

    if (!is_saturated) {
        h->xtn(output.b8, input.h8);
        return;
    }

    if (is_signed) {
        h->sqxtn(output.b8, input.h8);
        return;
    }
    h->uqxtn(output.b8, input.h8);
}

// Widens the b8 (8-bit) view of the source vector register into the h8 (16-bit)
// view of the destination vector register: SXTL when is_signed, UXTL otherwise.
//   h         - code generator that receives the instruction
//   in_idxs   - in_idxs[0] is the index of the source vector register
//   out_idxs  - out_idxs[0] is the index of the destination vector register
//   is_signed - true for sign extension, false for zero extension
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
static void cvt_byte_to_dbyte(dnnl::impl::cpu::aarch64::jit_generator* h, const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
                              bool is_signed) {
    using VecReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
    VecReg input(in_idxs.front());
    VecReg output(out_idxs.front());

    if (!is_signed) {
        h->uxtl(output.h8, input.b8);
        return;
    }
    h->sxtl(output.h8, input.b8);
}

template <cpu_isa_t isa>
static void jit_convert_process(dnnl::impl::cpu::aarch64::jit_generator* h,
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
Expand Down Expand Up @@ -120,6 +239,15 @@ jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa
output_type = node->get_output_element_type(0);
}

// Constructs a convert emitter from explicitly supplied precisions: forwards host,
// host_isa and exec_prc to jit_emitter and records input_prc / output_prc as the
// emitter's input_type / output_type.
jit_convert_emitter::jit_convert_emitter(jit_generator *host,
                                         cpu_isa_t host_isa,
                                         ov::element::Type input_prc,
                                         ov::element::Type output_prc,
                                         ov::element::Type exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {
    this->output_type = output_prc;
    this->input_type = input_prc;
}

void jit_convert_emitter::validate_types() const {
OV_CPU_JIT_EMITTER_ASSERT(one_of(input_type, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
"Unsupported input type: ", input_type.get_type_name());
Expand All @@ -138,6 +266,13 @@ jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *ho
: jit_convert_emitter(host, host_isa, node, exec_prc) {
}

// Precision-based constructor: does no work of its own, it only forwards every
// argument to the matching jit_convert_emitter constructor.
jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host,
                                                               cpu_isa_t host_isa,
                                                               ov::element::Type input_prc,
                                                               ov::element::Type output_prc,
                                                               ov::element::Type exec_prc)
    : jit_convert_emitter(host, host_isa, input_prc, output_prc, exec_prc) {}

void jit_convert_truncation_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
validate_types();
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
Expand All @@ -157,6 +292,13 @@ jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *ho
: jit_convert_emitter(host, host_isa, node, exec_prc) {
}

// Precision-based constructor: does no work of its own, it only forwards every
// argument to the matching jit_convert_emitter constructor.
jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host,
                                                               cpu_isa_t host_isa,
                                                               ov::element::Type input_prc,
                                                               ov::element::Type output_prc,
                                                               ov::element::Type exec_prc)
    : jit_convert_emitter(host, host_isa, input_prc, output_prc, exec_prc) {}

void jit_convert_saturation_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
validate_types();
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class jit_convert_emitter : public jit_emitter {
public:
jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
// Overload that takes the input and output precisions directly instead of an ov::Node.
// exec_prc defaults to f32.
jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);

size_t get_inputs_count() const override;

Expand All @@ -33,6 +35,8 @@ class jit_convert_truncation_emitter : public jit_convert_emitter {
public:
jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
// Overload that takes the input and output precisions directly instead of an ov::Node.
// exec_prc defaults to f32.
jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);

private:
void emit_impl(const std::vector<size_t>& in_idxs, const std::vector<size_t>& out_idxs) const override;
Expand All @@ -48,6 +52,8 @@ class jit_convert_saturation_emitter : public jit_convert_emitter {
public:
jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
// Overload that takes the input and output precisions directly instead of an ov::Node.
// exec_prc defaults to f32.
jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);

private:
void emit_impl(const std::vector<size_t>& in_idxs, const std::vector<size_t>& out_idxs) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "jit_load_store_emitters.hpp"
#include "cpu/aarch64/cpu_isa_traits.hpp"
#include "emitters/utils.hpp"
#include "utils.hpp"

using namespace Xbyak_aarch64;

Expand All @@ -20,7 +19,9 @@ jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::aarch64::jit_generator *host
ov::element::Type src_prc, ov::element::Type dst_prc, int load_num, int byte_offset,
ov::element::Type exec_prc, emitter_in_out_map in_out_type)
: jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), load_num_(load_num), byte_offset_(byte_offset),
src_prc_(src_prc), dst_prc_(dst_prc) {}
src_prc_(src_prc), dst_prc_(dst_prc) {
convert_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc));
}

void jit_load_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
Expand Down Expand Up @@ -135,68 +136,23 @@ void jit_load_emitter::emit_isa(const std::vector<size_t> &in_idxs, const std::v

switch (src_prc_) {
case ov::element::f32:
load_qbyte<isa>(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs);
switch (dst_prc_) {
case ov::element::f32:
break;
case ov::element::i32:
cvt_f32_to_i32<isa>(h, aux_vec_idxs, out_idxs);
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name());
}
break;
case ov::element::i32:
load_qbyte<isa>(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs);
switch (dst_prc_) {
case ov::element::f32:
cvt_i32_to_f32<isa>(h, aux_vec_idxs, out_idxs);
break;
case ov::element::i32:
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name());
}
break;
case ov::element::f16:
load_dbyte<isa>(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs);
switch (dst_prc_) {
case ov::element::f32:
cvt_f16_to_f32<isa>(h, aux_vec_idxs, out_idxs);
break;
case ov::element::i32:
cvt_f16_to_f32<isa>(h, aux_vec_idxs, aux_vec_idxs);
cvt_f32_to_i32<isa>(h, aux_vec_idxs, out_idxs);
break;
case ov::element::f16:
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name());
}
break;
case ov::element::i8:
case ov::element::u8:
load_byte<isa>(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs);
switch (dst_prc_) {
case ov::element::f32:
cvt_byte_to_dbyte<isa>(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed());
cvt_dbyte_to_i32<isa>(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed());
cvt_i32_to_f32<isa>(h, aux_vec_idxs, out_idxs);
break;
case ov::element::i32:
cvt_byte_to_dbyte<isa>(h, aux_vec_idxs, aux_vec_idxs, src_prc_.is_signed());
cvt_dbyte_to_i32<isa>(h, aux_vec_idxs, out_idxs, src_prc_.is_signed());
break;
case ov::element::i8:
case ov::element::u8:
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name());
}
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name());
}

if (src_prc_ != dst_prc_) {
convert_emitter->emit_code(aux_vec_idxs, out_idxs);
}
}

size_t jit_load_emitter::get_aux_gprs_count() const {
Expand All @@ -217,7 +173,15 @@ jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *ho
ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset,
arithmetic_mode mode, ov::element::Type exec_prc, emitter_in_out_map in_out_type)
: jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), store_num_(store_num), byte_offset_(byte_offset),
src_prc_(src_prc), dst_prc_(dst_prc), mode_(mode) {}
src_prc_(src_prc), dst_prc_(dst_prc) {
if (mode == arithmetic_mode::truncation) {
convert_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc));
} else if (mode == arithmetic_mode::saturation) {
convert_emitter.reset(new jit_convert_saturation_emitter(host, host_isa, src_prc, dst_prc, exec_prc));
} else {
OV_CPU_JIT_EMITTER_THROW("Unsupported Convert emitter.");
}
}

void jit_store_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
Expand Down Expand Up @@ -331,65 +295,20 @@ void jit_store_emitter::emit_isa(const std::vector<size_t> &in_idxs, const std::
OV_CPU_JIT_EMITTER_ASSERT(store_num_ <= static_cast<int>((get_vec_length() / dst_prc_.size())),
"Unexpected number of elements to store.");

if (src_prc_ != dst_prc_) {
convert_emitter->emit_code(in_idxs, aux_vec_idxs);
}

switch (dst_prc_) {
case ov::element::f32:
switch (src_prc_) {
case ov::element::f32:
break;
case ov::element::i32:
cvt_i32_to_f32<isa>(h, in_idxs, aux_vec_idxs);
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name());
}
store_qbyte<isa>(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs);
break;
case ov::element::i32:
switch (src_prc_) {
case ov::element::f32:
cvt_f32_to_i32<isa>(h, in_idxs, aux_vec_idxs);
break;
case ov::element::i32:
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name());
}
store_qbyte<isa>(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs);
break;
case ov::element::f16:
switch (src_prc_) {
case ov::element::f32:
cvt_f32_to_f16<isa>(h, in_idxs, aux_vec_idxs);
break;
case ov::element::i32:
cvt_i32_to_f32<isa>(h, in_idxs, aux_vec_idxs);
cvt_f32_to_f16<isa>(h, aux_vec_idxs, aux_vec_idxs);
break;
case ov::element::f16:
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name());
}
store_dbyte<isa>(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs);
break;
case ov::element::i8:
case ov::element::u8:
switch (src_prc_) {
case ov::element::f32:
cvt_f32_to_i32<isa>(h, in_idxs, aux_vec_idxs);
cvt_i32_to_dbyte<isa>(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation);
cvt_dbyte_to_byte<isa>(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation);
break;
case ov::element::i32:
cvt_i32_to_dbyte<isa>(h, in_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation);
cvt_dbyte_to_byte<isa>(h, aux_vec_idxs, aux_vec_idxs, dst_prc_.is_signed(), mode_ == arithmetic_mode::saturation);
break;
case ov::element::i8:
case ov::element::u8:
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name());
}
store_byte<isa>(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs);
break;
default:
Expand Down
Loading

0 comments on commit 52d8898

Please sign in to comment.