Skip to content

Commit

Permalink
register management; refactoring #1
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Aug 12, 2023
1 parent 5fd247c commit 802e49d
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ jit_power_emitter::jit_power_emitter(dnnl::impl::cpu::aarch64::jit_generator *ho

/// Number of inputs consumed by the power emitter: always two.
size_t jit_power_emitter::get_inputs_num() const {
    return 2;
}

/// Number of auxiliary vector registers this emitter asks for: always one.
size_t jit_power_emitter::aux_vecs_count() const {
    return 1;
}

/// Returns the set of supported input precision combinations.
///
/// The result holds exactly one combination: {f32, f32}.
/// The `node` argument is not used by this implementation.
std::set<std::vector<element::Type>> jit_power_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    const std::vector<element::Type> f32_inputs{element::f32, element::f32};

    std::set<std::vector<element::Type>> supported;
    supported.insert(f32_inputs);
    return supported;
}
Expand Down Expand Up @@ -229,12 +231,11 @@ void jit_power_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const s
TReg src1 = TReg(in_vec_idxs[1]);
h->uni_fcvtzs(src1.s, src1.s);
Xbyak_aarch64::VRegSElem s = src1.s[0];
// TODO: 1?
Xbyak_aarch64::WReg counter{1};

Xbyak_aarch64::WReg counter{aux_gpr_idxs[0]};
h->mov(counter, s);

// TODO: workaround
TReg aux0 = TReg(in_vec_idxs[1] + 1);
TReg aux0 = TReg(aux_vec_idxs[0]);
TReg dst = TReg(out_vec_idxs[0]);

Xbyak_aarch64::Label loop_label;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class jit_power_emitter : public jit_emitter {
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);

size_t get_inputs_num() const override;

size_t aux_vecs_count() const override;

static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);

private:
Expand Down
19 changes: 15 additions & 4 deletions src/plugins/intel_cpu/src/emitters/aarch64/jit_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,21 @@ size_t jit_emitter::aux_vecs_count() const {
/// This implementation of table preparation is a no-op.
void jit_emitter::prepare_table() {}

void jit_emitter::emitter_preamble(const std::vector<size_t> &in_idxs,
const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs,
const std::vector<size_t> &pool_gpr_idxs) const {
/// Records the register pools handed to the emitter before code emission.
///
/// Every index from `pool_vec_idxs` is appended to `aux_vec_idxs` and every
/// index from `pool_gpr_idxs` is appended to `aux_gpr_idxs` (both narrowed to
/// uint32_t). `in_idxs` and `out_idxs` are not used here.
///
/// Throws (via IE_THROW) when fewer vector registers are available than
/// `aux_vecs_count()` requires.
///
/// NOTE(review): indices are appended without clearing the member vectors
/// first; presumably they are reset elsewhere (e.g. in emitter_postamble) —
/// confirm.
void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
                                   const std::vector<size_t>& out_idxs,
                                   const std::vector<size_t>& pool_vec_idxs,
                                   const std::vector<size_t>& pool_gpr_idxs) const {
    // Take over the vector register pool.
    for (size_t i = 0; i < pool_vec_idxs.size(); ++i) {
        const auto vec_idx = static_cast<uint32_t>(pool_vec_idxs[i]);
        aux_vec_idxs.push_back(vec_idx);
    }

    // Take over the general-purpose register pool.
    for (size_t i = 0; i < pool_gpr_idxs.size(); ++i) {
        const auto gpr_idx = static_cast<uint32_t>(pool_gpr_idxs[i]);
        aux_gpr_idxs.push_back(gpr_idx);
    }

    // The emitter cannot work with fewer auxiliary vector registers than it declared.
    const bool enough_vecs = aux_vec_idxs.size() >= aux_vecs_count();
    if (!enough_vecs) {
        IE_THROW() << "Failed to allocate required number of vector registers";
    }
}

void jit_emitter::emitter_postamble() const {
Expand Down
10 changes: 7 additions & 3 deletions src/plugins/intel_cpu/src/emitters/aarch64/jit_emitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class jit_emitter : public ov::snippets::Emitter {
size_t get_max_vecs_count() const;
size_t get_vec_length() const;

mutable std::vector<uint32_t> aux_vec_idxs;
mutable std::vector<uint32_t> aux_gpr_idxs;

dnnl::impl::cpu::aarch64::jit_generator* h;
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa_;
InferenceEngine::Precision exec_prc_;
Expand All @@ -88,9 +91,10 @@ class jit_emitter : public ov::snippets::Emitter {

virtual void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const = 0;

virtual void emitter_preamble(
const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const;
virtual void emitter_preamble(const std::vector<size_t>& in_idxs,
const std::vector<size_t>& out_idxs,
const std::vector<size_t>& pool_vec_idxs,
const std::vector<size_t>& pool_gpr_idxs) const;

virtual void emitter_postamble() const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,15 +431,19 @@ void jit_uni_eltwise_generic<isa>::compute_eltwise_op() {
std::vector<size_t> out_idxs;
out_idxs.push_back(vmm_dst.getIdx());

eltwise_emitter->emit_code(in_idxs, out_idxs, aux_idxs);
std::vector<size_t> gpr_idxs;
// TODO: GPR pool allocation is not completed; indices 12 and 13 are hard-coded for now
gpr_idxs.push_back(12);
gpr_idxs.push_back(13);

eltwise_emitter->emit_code(in_idxs, out_idxs, aux_idxs, gpr_idxs);
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_uni_eltwise_generic<isa>::apply_post_ops() {
int input_idx = eltwise_emitter->get_inputs_num();
int eltwise_post_op_idx = 0;
for (size_t i = 1; i < ops_list_.size(); i++) {
// TODO: FakeQuantize is not supported
if (ops_list_[i] == ov::intel_cpu::Type::Eltwise) {
std::vector<size_t> in_idxs;
std::vector<size_t> aux_idxs;
Expand All @@ -456,6 +460,7 @@ void jit_uni_eltwise_generic<isa>::apply_post_ops() {

eltwise_post_op_idx++;
} else {
// TODO: FakeQuantize is not supported
IE_THROW(Unexpected) << "Eltwise jit kernel: unexpected operation type";
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,28 +128,28 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
Xbyak_aarch64::XReg reg_dst = x10;
TReg vmm_dst {9};

inline XReg get_src_reg(int idx) {
// Maps a source index to the X register numbered 11 + idx.
// Throws (IE_THROW(Unexpected)) when idx is greater than MAX_ELTWISE_INPUTS.
// NOTE(review): the check uses `>`, so idx == MAX_ELTWISE_INPUTS is accepted — confirm that is intended.
inline XReg get_src_reg(uint32_t idx) {
    const bool out_of_range = idx > MAX_ELTWISE_INPUTS;
    if (out_of_range) {
        IE_THROW(Unexpected) << "source vector ptr register " << idx << " is not supported";
    }

    const uint32_t reg_idx = 11 + idx;
    return XReg(reg_idx);
}

inline TReg get_vmm_reg(int idx) {
// Maps a source index to the vector register (TReg) numbered 1 + idx.
// Throws (IE_THROW(Unexpected)) when idx is greater than MAX_ELTWISE_INPUTS.
// NOTE(review): the check uses `>`, so idx == MAX_ELTWISE_INPUTS is accepted — confirm that is intended.
inline TReg get_vmm_reg(uint32_t idx) {
    const bool out_of_range = idx > MAX_ELTWISE_INPUTS;
    if (out_of_range) {
        IE_THROW(Unexpected) << "source vector register " << idx << " is not supported";
    }

    const uint32_t reg_idx = 1 + idx;
    return TReg(reg_idx);
}

inline SReg get_scl_reg(int idx) {
// Maps a source index to the scalar register (SReg) numbered 1 + idx.
// Throws (IE_THROW(Unexpected)) when idx is greater than MAX_ELTWISE_INPUTS.
// NOTE(review): the check uses `>`, so idx == MAX_ELTWISE_INPUTS is accepted — confirm that is intended.
inline SReg get_scl_reg(uint32_t idx) {
    const bool out_of_range = idx > MAX_ELTWISE_INPUTS;
    if (out_of_range) {
        IE_THROW(Unexpected) << "source scalar register " << idx << " is not supported";
    }

    const uint32_t reg_idx = 1 + idx;
    return SReg(reg_idx);
}

inline TReg get_aux_vmm(int idx) {
inline TReg get_aux_vmm(uint32_t idx) {
if (idx > 2) {
IE_THROW(Unexpected) << "aux vector register " << idx << " is not supported";
}
Expand Down

0 comments on commit 802e49d

Please sign in to comment.