Skip to content

Commit

Permalink
[TMP] Reorder before execution: first implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Oct 10, 2024
1 parent ed8b6ea commit a1ef9da
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 7 deletions.
4 changes: 4 additions & 0 deletions src/common/snippets/src/runtime_configurator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ void RuntimeConfigurator::update_data_offsets(const std::vector<VectorDims>& sha
dim_step *= shape[i + 1];
offsets[i + idx_stride] = shape[i] != 1 ? dim_step : 0;
}
// TODO: remove this hardcode
if (!std::getenv("REFERENCE") && i == 1)
offsets[3] = 2048 * 2;
std::cout << "offsets[" << i << "] = " << ov::PartialShape(offsets) << std::endl;
if (!layout.empty()) {
std::vector<size_t> reordered_offsets(offsets.size());
const auto is_input = i < m_in_num;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "snippets/lowered/loop_manager.hpp"
#include "emitters/plugin/x64/utils.hpp"
#include "nodes/common/cpu_memcpy.h"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

#define DTYPE_CAST(X) static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(X))
Expand Down Expand Up @@ -233,6 +234,9 @@ void BrgemmCopyBKernel::emit_brgemm_copy_b_kernel_call(size_t N, size_t K, size_
spill.postamble();
}

// NOTE(review): temporary debug instrumentation for this [TMP] experiment.
// File-scope mutable globals: captured on the first call to
// BrgemmCopyBKernel::execute (set-once, see the base_addr checks there) and
// then used to print the address stride of subsequent src/dst buffers.
// No synchronization is visible, so concurrent kernel execution would race
// on these; they are also never reset between inferences. Remove before merge.
uintptr_t base_addr_src = 0;
uintptr_t base_addr_dst = 0;

void BrgemmCopyBKernel::execute(matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, const void* comp, size_t N, size_t K) {
auto ctx = matmul::jit_brgemm_matmul_copy_b_t::ctx_t();
ctx.current_N_blk = N;
Expand All @@ -244,8 +248,37 @@ void BrgemmCopyBKernel::execute(matmul::jit_brgemm_matmul_copy_b_t* kernel, cons
ctx.current_K_start = 0;
ctx.current_K_iters = K;

if (base_addr_src == 0)
base_addr_src = reinterpret_cast<uintptr_t>(src);
else
std::cout << "Stride from base_addr_src = " << reinterpret_cast<uintptr_t>(src) - base_addr_src << std::endl;
if (base_addr_dst == 0)
base_addr_dst = reinterpret_cast<uintptr_t>(dst);
else
std::cout << "Stride from base_addr_dst = " << reinterpret_cast<uintptr_t>(dst) - base_addr_dst << std::endl;

OV_CPU_JIT_EMITTER_ASSERT(kernel, "Kernel hasn't been created");
(*kernel)(&ctx);
if (std::getenv("REFERENCE")) {
(*kernel)(&ctx);
std::cout << "Ref Repacked, KN = " << K * N << std::endl;
const auto* data = reinterpret_cast<const bfloat16*>(dst);
for (size_t i = 0; i < K * N; ++i) {
std::cout << static_cast<float>(data[i]) << "\t";
}
std::cout << "\n";
} else {
auto srcPtr = static_cast<const uint8_t*>(src);
auto dstPtr = const_cast<uint8_t*>(static_cast<const uint8_t*>(dst));

auto copySize = K * N * sizeof(bfloat16);
cpu_memcpy(dstPtr, srcPtr, copySize);
std::cout << "Just copy, KN = " << K * N << std::endl;
const auto* data = reinterpret_cast<const bfloat16*>(dst);
for (size_t i = 0; i < K * N; ++i) {
std::cout << static_cast<float>(data[i]) << "\t";
}
std::cout << "\n";
}
}

BrgemmCopyBKernelExecutor::BrgemmCopyBKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmCopyBKernelConfig config)
Expand Down
6 changes: 0 additions & 6 deletions src/plugins/intel_cpu/src/nodes/reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,7 @@
#include <common/primitive_hashing_utils.hpp>
#include <shape_inference/shape_inference_pass_through.hpp>

#include "convert.h"
#include "cpu/x64/cpu_isa_traits.hpp"
#include "nodes/common/cpu_convert.h"
#include "nodes/common/cpu_memcpy.h"
#include "nodes/common/reorder_prim.h"
#include "openvino/core/parallel.hpp"
#include "shape_inference/shape_inference_pass_through.hpp"
#include "utils/precision_support.h"
#include "nodes/executors/executor.hpp"
#include "nodes/executors/transpose_list.hpp"
Expand Down
56 changes: 56 additions & 0 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
//
#include "subgraph.h"

#include "nodes/reorder.h"
#include "nodes/common/reorder_prim.h"
#include "memory_desc/cpu_memory_desc_utils.h"
#include "common/primitive_hashing_utils.hpp"
#include "dnnl_extension_utils.h"
#include "onednn/dnnl.h"
Expand Down Expand Up @@ -756,6 +759,36 @@ void Subgraph::optimizeIR() {
void Subgraph::prepareParams() {
const auto& cache = context->getParamsCache();

const auto& input_shape = getSrcMemoryAtPort(0)->getDescPtr()->getShape().getStaticDims();
const auto& b_shape = getSrcMemoryAtPort(1)->getDescPtr()->getShape().getStaticDims();

// Note: this code was tested only on static shapes, in case of dynamic M will most likely fail
const auto M = DnnlExtensionUtils::convertToDnnlDim(*++input_shape.rbegin());
const auto K = DnnlExtensionUtils::convertToDnnlDim(*input_shape.rbegin());
const auto N = DnnlExtensionUtils::convertToDnnlDim(*b_shape.rbegin());
const auto B_2 = DnnlExtensionUtils::convertToDnnlDim(*++b_shape.begin());

auto get_wei_desc = [&]() {
const auto inputDesc = dnnl::memory::desc({1, M, K}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::abc);
// Notes:
// 1. "Any" layout must be set to enable weights layout selection heuristics
// 2. Shape must be in NK order (even if the original shape is in KN order)
const auto BDesc = dnnl::memory::desc({B_2, K, N}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::any);
const auto outputDesc = dnnl::memory::desc({B_2, M, N}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::abc);

// Hack: we create inner product primitive just to know which weights layout was chosen by OneDNN heuristics
// Then, this layout is used in Snippets implementation
auto mm_desc = dnnl::matmul::primitive_desc(getEngine(), inputDesc, BDesc, outputDesc);
// Note: based on weights layout, it is necessary to set N block sizes inside Snippets.
// Example: in case of "AB16b32a" layout, N_block must be 32. K_block can be any
std::cout << "[ INFO ] matmul primitive selected the following B layout for BF16: "
<< DnnlExtensionUtils::makeDescriptor(mm_desc.weights_desc())->serializeFormat() << std::endl;
return DnnlExtensionUtils::makeDescriptor(mm_desc.weights_desc());
};

// auto reorder = ov::intel_cpu::getReorderPrim(context->getParamsCache(), getEngine(), originalMemDesc->getDnnlDesc(), get_wei_desc());
requested_desc_b = get_wei_desc();

auto builder = [this, &cache](const SubgraphKey& key) -> std::shared_ptr<SubgraphExecutor> {
const auto& snippet = subgraph_attrs->snippet;

Expand Down Expand Up @@ -843,6 +876,29 @@ bool Subgraph::created() const {

// Executes the subgraph. Experimental ([TMP]) addition: when prepareParams()
// selected a oneDNN weights layout (requested_desc_b), input 1 is repacked
// into that layout here, right before running the snippets executor.
void Subgraph::execute(dnnl::stream strm) {
    OPENVINO_ASSERT(execPtr, "Can't execute Subgraph node. Primitive didn't created");
    if (requested_desc_b) {
        // Allocate memory with the oneDNN-chosen layout and reorder the
        // original B input into it via Memory::load().
        auto repacked_memory = std::make_shared<Memory>(getEngine(), requested_desc_b);
        repacked_memory->load(*srcMemPtrs[1]);
        // Debug switch: with REFERENCE set in the environment, keep the
        // original (non-repacked) input so a reference run can be compared
        // against the repacked path.
        if (!std::getenv("REFERENCE"))
            srcMemPtrs[1] = repacked_memory;

        // TODO: remove
        // NOTE(review): everything below is debug dumping of the repacked
        // tensor to stdout — must not ship.
        const auto& input_shape = getSrcMemoryAtPort(0)->getDescPtr()->getShape().getStaticDims();
        const auto& b_shape = getSrcMemoryAtPort(1)->getDescPtr()->getShape().getStaticDims();
        // K is taken from the innermost dim of input 0, N from the innermost
        // dim of input 1 (comment in prepareParams warns this was only tested
        // on static shapes).
        const auto K = DnnlExtensionUtils::convertToDnnlDim(*input_shape.rbegin());
        const auto N = DnnlExtensionUtils::convertToDnnlDim(*b_shape.rbegin());
        auto* data = repacked_memory->getDataAs<const bfloat16>();
        std::cout << "Repacked, KN = " << K * N << std::endl;
        // NOTE(review): upper_bound comes from getSize() — confirm whether
        // that is bytes or elements; indexing a bfloat16* with a byte count
        // would read past the buffer. TODO confirm.
        auto upper_bound = repacked_memory->getSize();
        for (decltype(upper_bound) i = 0; i < upper_bound; ++i) {
            std::cout << static_cast<float>(data[i]) << "\t";
            // NOTE(review): exact float equality against the magic constant
            // 5.21875f is used as a sentinel to detect where a stride begins
            // and to shorten the dump — fragile by design; remove together
            // with the rest of this debug block.
            if (static_cast<float>(data[i]) == 5.21875f) {
                // std::cout << "Stride is found: " << i << std::endl;
                upper_bound = i + K * N;
            }
        }
        std::cout << "\n";
    }
    execPtr->exec(srcMemPtrs, dstMemPtrs);
}

Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/nodes/subgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class Subgraph : public Node {
mutable std::vector<VectorDims> in_shapes;

std::shared_ptr<SubgraphExecutor> execPtr = nullptr;

ov::intel_cpu::MemoryDescPtr requested_desc_b;
};

class Subgraph::SubgraphCodeGenerator {
Expand Down

0 comments on commit a1ef9da

Please sign in to comment.