Skip to content

Commit

Permalink
table was added
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Sep 27, 2023
1 parent 609cac7 commit 7d82865
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,21 +169,38 @@ jit_power_emitter::jit_power_emitter(dnnl::impl::cpu::aarch64::jit_generator *ho
const float power,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)), power(power) {
auto powerStaticNode = ov::as_type_ptr<ov::snippets::op::PowerStatic>(node);
if (powerStaticNode == nullptr) {
IE_THROW() << "Can't cast to snippets::op::PowerStatic";
}

// scale = 1.f;
// shift = 0.f;

prepare_table();
}

// Constructs a power emitter from an explicit execution precision instead of
// a graph node. The exponent is stored in `power`, then prepare_table() is
// called so the table entries are registered and laid out.
jit_power_emitter::jit_power_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
                                     dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                                     const float power,
                                     const Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc),
      power(power) {
    prepare_table();
}

// The power emitter reports a single input.
size_t jit_power_emitter::get_inputs_count() const {
    return 1;
}

// Reports the number of auxiliary vector registers for this emitter.
// Only the post-change definition (2) is kept: the `return 1` variant that
// appeared next to it was the removed side of the diff, and keeping both
// would define the same function twice.
size_t jit_power_emitter::get_aux_vecs_count() const {
    return 2;
}

// Reports one auxiliary general-purpose register.
// NOTE(review): presumably this is the register emitter_preamble() takes from
// the end of the aux GPR list to hold the table pointer (p_table) — confirm.
size_t jit_power_emitter::get_aux_gprs_count() const {
    return 1;
}

// Registers the constants this emitter keeps in the data table.
// Only the exponent is stored for now: it is converted with
// dnnl::impl::float2int and pushed under the key "power" with the broadcast
// flag set.
void jit_power_emitter::register_table_entries() {
    const auto power_bits = dnnl::impl::float2int(power);
    push_arg_entry_of("power", power_bits, true);

    // Entries that are not registered yet (kept for reference):
    // push_arg_entry_of("scale", float2int(scale), true);
    // push_arg_entry_of("shift", float2int(shift), true);
    // push_arg_entry_of("one", float2int(1.f), true);
}

// Returns the supported precision combinations: a single (f32, f32) entry.
// The `node` argument is not consulted.
std::set<std::vector<element::Type>> jit_power_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    const std::vector<element::Type> f32_pair{element::f32, element::f32};
    return {f32_pair};
}
Expand Down Expand Up @@ -249,7 +266,11 @@ void jit_power_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const s

for (auto i = 0; i < 4; i++) {
h->mov(s0, src.s[i]);
h->fmov(s1, power);

//const float power2 = 1.23;
//h->fmov(s1, power2);
h->ldr(s1, table_val("power"));

h->blr(x8);

Xbyak_aarch64::WReg w0(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class jit_multiply_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

// TODO: jit_power_emitter => jit_power_static_emitter
class jit_power_emitter : public jit_emitter {
public:
jit_power_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
Expand All @@ -97,6 +98,8 @@ class jit_power_emitter : public jit_emitter {

size_t get_aux_gprs_count() const override;

void register_table_entries() override;

static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);

private:
Expand Down
38 changes: 38 additions & 0 deletions src/plugins/intel_cpu/src/emitters/aarch64/jit_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ void jit_emitter::emit_code(const std::vector<size_t> &in_idxs,
emitter_postamble();
}

// Emits the constant table into the generated code.
// After align(64) the table label is bound, then every registered entry is
// written as 32-bit words in entry_map_ iteration order — the same order
// prepare_table() walks when it assigns offsets. A broadcast entry is repeated
// until get_vec_length() bytes are filled; a non-broadcast entry takes a
// single word.
void jit_emitter::emit_data() const {
    // Entries are inserted with dd(), so a table value must be exactly 4 bytes.
    // This is a compile-time property, so it is checked with static_assert
    // (a runtime assert would disappear in NDEBUG builds).
    static_assert(sizeof(table_entry_val_t) == 4, "table entries must be 4 bytes to be emitted with dd()");

    h->align(64);
    h->L(*l_table);

    // Run through the map and insert the values stored there.
    for (const auto &entry : entry_map_) {
        const auto &te = entry.second;
        const auto len = te.bcast ? get_vec_length() : sizeof(table_entry_val_t);
        for (size_t d = 0; d < len; d += sizeof(table_entry_val_t)) {
            h->dd(te.val);
        }
    }
}

std::set<std::vector<element::Type>> jit_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
return {};
}
Expand All @@ -45,6 +61,18 @@ size_t jit_emitter::get_aux_vecs_count() const {
}

// Registers the table entries of the concrete emitter and assigns each one
// its byte offset inside the table.
void jit_emitter::prepare_table() {
    register_table_entries();

    // All entries are registered at this point, so offsets can be fixed.
    // Nothing should be registered afterwards: emit_data() walks entry_map_
    // in the same order when it writes the table, and relies on these offsets.
    size_t next_off = 0;
    for (auto &kv : entry_map_) {
        auto &entry = kv.second;
        // A broadcast entry occupies a full vector, otherwise a single value.
        const auto entry_size = entry.bcast ? get_vec_length() : sizeof(table_entry_val_t);
        entry.off = next_off;
        next_off += entry_size;
    }
}

void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
Expand All @@ -66,6 +94,16 @@ void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
for (auto idx : pool_aux_gpr_idxs) {
aux_gpr_idxs.push_back(static_cast<uint32_t>(idx));
}

if (!entry_map_.empty()) {
// The last aux_gpr_idx is reserved for p_table; aux_gpr_idxs from index 0 onwards remain available for other purposes.
p_table = Xbyak_aarch64::XReg(aux_gpr_idxs[aux_gpr_idxs.size() - 1]);
aux_gpr_idxs.erase(aux_gpr_idxs.end() - 1);
}

if (!entry_map_.empty()) {
load_table_addr();
}
}

void jit_emitter::emitter_postamble() const {
Expand Down
68 changes: 66 additions & 2 deletions src/plugins/intel_cpu/src/emitters/aarch64/jit_emitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class jit_emitter : public ov::snippets::Emitter {
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32,
const float alpha = 0.f,
emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec) :
Emitter(), h(host), host_isa_(host_isa), exec_prc_(exec_prc), alpha(alpha), in_out_type_(in_out_type) {
Emitter(), h(host), host_isa_(host_isa), exec_prc_(exec_prc),
alpha(alpha), in_out_type_(in_out_type), p_table(0), l_table (new Xbyak_aarch64::Label()) {
}

jit_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
Expand All @@ -47,7 +48,8 @@ class jit_emitter : public ov::snippets::Emitter {
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32,
const float alpha = 0.f,
emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec) :
Emitter(), h(host), host_isa_(host_isa), exec_prc_(exec_prc), alpha(alpha), in_out_type_(in_out_type) {
Emitter(), h(host), host_isa_(host_isa), exec_prc_(exec_prc),
alpha(alpha), in_out_type_(in_out_type), p_table(0), l_table (new Xbyak_aarch64::Label()) {
}

void emit_code(
Expand All @@ -56,6 +58,8 @@ class jit_emitter : public ov::snippets::Emitter {
const std::vector<size_t> &pool_vec_idxs = {},
const std::vector<size_t> &pool_gpr_idxs = {}) const override;

void emit_data() const override;

virtual size_t get_inputs_count() const = 0;
virtual size_t get_aux_vecs_count() const;
virtual size_t get_aux_gprs_count() const;
Expand Down Expand Up @@ -84,6 +88,26 @@ class jit_emitter : public ov::snippets::Emitter {
virtual void prepare_table();
virtual void register_table_entries() {}

// Loads the address of the table label into p_table using adr.
void load_table_addr() const {
    h->adr(p_table, *l_table);
}

// Types describing constant-table entries.
// we accept only 32bit hexadecimal table values to avoid any rounding
using table_entry_val_t = uint32_t;
using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table
using table_entry_bcast_t = bool; // true => bcast value

// A table entry as supplied by a caller: the 32-bit value and whether it is
// broadcast. It carries no offset.
struct table_entry_t {
table_entry_val_t val;
table_entry_bcast_t bcast;
};
// A table entry as stored in entry_map_: the same value/broadcast pair plus the
// byte offset assigned by prepare_table(). The member order matters because
// push_arg_entry_of() initialises it positionally as {off, val, bcast}.
struct mapped_table_entry_t {
table_entry_offset_t off;
table_entry_val_t val;
table_entry_bcast_t bcast;
};

// Register used as the table base: emitter_preamble() assigns it from the last
// auxiliary GPR when the table is non-empty, and load_table_addr() fills it
// with the address of l_table.
// NOTE(review): presumably `mutable` because it is set from const emission
// methods — confirm.
mutable Xbyak_aarch64::XReg p_table;
// Label marking the start of the table; bound by emit_data() and referenced by
// load_table_addr().
mutable std::shared_ptr<Xbyak_aarch64::Label> l_table;

virtual void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const = 0;

virtual void emitter_preamble(const std::vector<size_t>& in_idxs,
Expand All @@ -93,9 +117,49 @@ class jit_emitter : public ov::snippets::Emitter {

virtual void emitter_postamble() const;

// XReg table_val(std::string key, size_t key_off_val_shift = 0) const {
// auto off = table_off(key, key_off_val_shift);
// return h->ptr[p_table + off];
// }

// Key -> entry containers. table_t holds entries without offsets (input to
// push_entries_of()); mapped_table_t holds entries with their table offsets.
using table_t = std::multimap<std::string, table_entry_t>;
using mapped_table_t = std::multimap<std::string, mapped_table_entry_t>;

// All registered table entries. Filled through push_arg_entry_of(), given
// offsets by prepare_table(), and written out by emit_data().
mapped_table_t entry_map_;

// Builds the address operand for the table entry registered under `key`:
// p_table plus the byte offset computed by table_off(). `key_off_val_shift`
// is forwarded to table_off() to step past the first value of the entry.
Xbyak_aarch64::AdrImm table_val(std::string key, size_t key_off_val_shift = 0) const {
    // The offset is narrowed to int32_t before being handed to ptr().
    const auto byte_off = static_cast<int32_t>(table_off(key, key_off_val_shift));
    return Xbyak_aarch64::ptr(p_table, byte_off);
}

void push_arg_entry_of(const std::string key, const table_entry_val_t val, const bool broadcast) {
mapped_table_entry_t te {0, val, broadcast};
entry_map_.insert(std::make_pair(key, te));
}

void push_entries_of(const table_t &t) {
for (auto it = t.begin(); it != t.end(); it++) {
auto key = (*it).first;
auto te = (*it).second; // copy values from table
push_arg_entry_of(key, te.val, te.bcast);
}
}

private:
mutable std::vector<size_t> preserved_vec_idxs;
mutable std::vector<size_t> preserved_gpr_idxs;

size_t table_off(std::string& key, size_t key_off_val_shift = 0) const {
// assumption: all table entries sharing the same key also
// share their broadcast property
// TODO: enforce through data structure
const auto it = entry_map_.find(key); // search an entry for a key
assert(it != entry_map_.end());
const auto &te = (*it).second;
const auto scale = te.bcast ? get_vec_length() : sizeof(table_entry_val_t);
return te.off + key_off_val_shift * scale;
}
};

} // namespace aarch64
Expand Down

0 comments on commit 7d82865

Please sign in to comment.