[WIP] Adapt the rest of the pipeline to FC tokenization
v-Golubev committed Sep 30, 2024
1 parent 1eac8d6 commit e954f6b
Showing 8 changed files with 104 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/common/snippets/src/pass/common_optimizations.cpp
@@ -47,6 +47,7 @@ CommonOptimizations::CommonOptimizations(const SnippetsTokenization::Config& con
// Then if Subgraph contains FakeQuantize we enable specific transformation for quantized subgraphs.
ov::pass::Manager manager(get_pass_config(), "Snippets:CommonOptimizations");
REGISTER_SNIPPETS_PASS(manager, ov::snippets::pass::TransformConvertToConvertTruncation, true);
// Note: in the case of FullyConnected, some common optimizations shouldn't be applied (at least ExplicitTransposeMatMulInputs)
REGISTER_SNIPPETS_PASS(manager, ov::snippets::pass::ExplicitTransposeMatMulInputs, is_domain_sensitive);
REGISTER_SNIPPETS_PASS(manager, ov::snippets::pass::CommonFakeQuantizeDecomposition, is_quantized);
REGISTER_SNIPPETS_PASS(manager, ov::snippets::pass::SoftmaxReshapeElimination, is_domain_sensitive);
19 changes: 17 additions & 2 deletions src/common/snippets/src/utils/loop_utils.cpp
@@ -62,17 +62,32 @@ inline void init_work_amount(const LoopInfoPtr& loop_info) {
} // namespace

void update_data_pointer_shifts(const UnifiedLoopInfoPtr& loop_info) {
static size_t loop_id = 0;
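// Note: this static counter persists across calls, so the hard-coded loop ids below only match the first loops processed in the process (i.e. they rely on a single compiled subgraph with a fixed loop order)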
OPENVINO_ASSERT(loop_info != nullptr, "UnifiedLoopInfo is nullptr, nothing to update");
const auto work_amount = loop_info->get_work_amount();
const auto input_count = loop_info->get_input_count();
const auto output_count = loop_info->get_output_count();

auto update_shifts = [&work_amount, &input_count, &output_count](LoopPort& loop_port, UnifiedLoopInfo::LoopPortDesc& ptr_shifts_params) {
size_t idx = 0;
auto update_shifts = [&](LoopPort& loop_port, UnifiedLoopInfo::LoopPortDesc& ptr_shifts_params) {
ptr_shifts_params.ptr_increment = get_ptr_increment(loop_port, work_amount,
loop_port.expr_port->get_type() == ExpressionPort::Input ? input_count : output_count);
// Dirty hack
{
// Loop by K
if ((loop_id == 0 || loop_id == 4) && idx == 1) {
ptr_shifts_params.ptr_increment = 32; // increment = inner_N_block size
}
// Loop by N
if ((loop_id == 1 || loop_id == 5) && idx == 1) {
// increment = K dimension rounded by K block
ptr_shifts_params.ptr_increment = *++loop_port.expr_port->get_descriptor_ptr()->get_shape().rbegin();
}
}
ptr_shifts_params.finalization_offset = get_finalization_offset(work_amount, ptr_shifts_params.ptr_increment);
idx++;
};
loop_info->iterate_through_infos(update_shifts);
loop_id++;
}

void update_runtime_parameters(const UnifiedLoopInfoPtr& loop_info) {
@@ -287,7 +287,8 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
LDB = brgemm_utils::repacking::compute_out_leading_dim(N, brgemm_node->get_input_element_type(1));

// Hack to imitate a blocked weights layout: LDB is overridden unconditionally (including the repacking branch above) with the N block size from the N_b environment variable (64 if unset)
LDB = std::getenv("N_b") ? std::atoi(std::getenv("N_b")) : 64;
config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
}

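The N_b override above, and the M_b/K_b/N_b and M/K/N overrides further below, all follow the same std::getenv/std::atoi pattern; std::atoi silently yields 0 for a non-numeric value. A minimal sketch of a safer helper, purely illustrative and not part of this commit (the name env_or_default is hypothetical):

#include <cstdlib>
#include <string>

// Hypothetical helper: read an unsigned value from an environment variable, falling back to a default.
static size_t env_or_default(const char* name, size_t fallback) {
    const char* value = std::getenv(name);
    if (value == nullptr)
        return fallback;
    try {
        return static_cast<size_t>(std::stoul(value));
    } catch (const std::exception&) {
        return fallback;  // non-numeric or out-of-range values fall back to the default
    }
}

// Usage sketch: LDB = env_or_default("N_b", 64);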
3 changes: 3 additions & 0 deletions src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp
@@ -26,6 +26,9 @@ MemoryPtr prepareWeightsMemory(const DnnlMemoryDescPtr srcWeightDesc,
const auto& eng = context->getEngine();
const auto& format = dstWeightDesc->serializeFormat();

std::cout << "[ INFO ] prepareWeightsMemory info\n";
std::cout << "Format: from " << srcWeightDesc->serializeFormat() << " to " << dstWeightDesc->serializeFormat() << std::endl;

const auto privateWeightCache = context->getPrivateWeighCache();
OPENVINO_ASSERT(privateWeightCache, "privateWeightCache is nullptr");
if (privateWeightCache) {
47 changes: 47 additions & 0 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -40,6 +40,10 @@
#include "transformations/snippets/x64/shape_inference.hpp"
#endif

#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "memory_desc/cpu_memory_desc_utils.h"
#include "nodes/executors/dnnl/dnnl_utils.hpp"

#include "utils/cpu_utils.hpp"
#include "utils/ngraph_utils.hpp"

@@ -753,8 +757,51 @@ void Subgraph::optimizeIR() {
control_flow_config, control_flow_passes);
}

DnnlMemoryDescPtr makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc,
const DnnlMemoryDescPtr dstDesc,
bool weightsNonTransposed) {
if (!weightsNonTransposed)
return srcDesc;

const auto& weiDesc = srcDesc->getDnnlDesc();
const auto reorderedWeiDesc =
dnnl::memory::desc{weiDesc.get_dims(), weiDesc.get_data_type(), dnnl::memory::format_tag::ba};
const auto transposedWeiDesc = DnnlExtensionUtils::makeDescriptor(reorderedWeiDesc.reshape(dstDesc->getDnnlDesc().get_dims()));
return transposedWeiDesc;
}

void Subgraph::prepareParams() {
const auto& cache = context->getParamsCache();
const auto& input_shape = getSrcMemoryAtPort(0)->getDescPtr()->getShape().getStaticDims();
const auto& weights_shape = getSrcMemoryAtPort(1)->getDescPtr()->getShape().getStaticDims();

const auto M = DnnlExtensionUtils::convertToDnnlDim(*++input_shape.rbegin());
const auto K = DnnlExtensionUtils::convertToDnnlDim(*input_shape.rbegin());
const auto N = DnnlExtensionUtils::convertToDnnlDim(*weights_shape.rbegin());
auto get_wei_desc = [&]() {
const auto inputDesc = dnnl::memory::desc({M, K}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
const auto weightsDesc = dnnl::memory::desc({N, K}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::any);
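// format_tag::any lets oneDNN report its preferred (typically blocked) weights layout, queried below via fc_desc.weights_desc()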
const auto biasDesc = dnnl::memory::desc();
const auto outputDesc = dnnl::memory::desc({M, N}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);

auto fc_desc = dnnl::inner_product_forward::primitive_desc(context->getEngine(),
dnnl::prop_kind::forward_inference,
inputDesc,
weightsDesc,
biasDesc,
outputDesc);
auto weiDesc = DnnlExtensionUtils::makeDescriptor(fc_desc.weights_desc());
return weiDesc;
};
auto prepareWeightsMemory = [&]() {
const auto memory = getSrcMemoryAtPort(1);
auto originalMemDesc = DnnlExtensionUtils::makeDescriptor(dnnl::memory::desc({N, K}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab));
const auto blocked_desc = get_wei_desc();
originalMemDesc = makeTransposedWeightDescriptor(originalMemDesc, blocked_desc, true);
const auto exec_context = std::make_shared<ExecutorContext>(context, std::vector<impl_desc_type>{}, privateWeightCache);
srcMemPtrs[1] = utils::prepareWeightsMemory(originalMemDesc, blocked_desc, memory, exec_context, true);
};
prepareWeightsMemory();

auto builder = [this, &cache](const SubgraphKey& key) -> std::shared_ptr<SubgraphExecutor> {
const auto& snippet = subgraph_attrs->snippet;
@@ -70,7 +70,10 @@ std::tuple<size_t, size_t, size_t> BrgemmCPUBlocking::get_blocking_params(const
n_blk = get_full_dim_value();
k_blk = get_full_dim_value();
}
return std::make_tuple(m_blk, n_blk, k_blk);
const size_t M = std::getenv("M_b") ? std::atoi(std::getenv("M_b")) : m_blk;
const size_t K = std::getenv("K_b") ? std::atoi(std::getenv("K_b")) : k_blk;
const size_t N = std::getenv("N_b") ? std::atoi(std::getenv("N_b")) : n_blk;
return std::make_tuple(M, N, K);
}

SpecificIterationHandlers BrgemmCPUBlocking::get_k_loop_handlers(size_t work_amount, size_t block_size) const {
@@ -404,10 +404,13 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompression);
CPU_SET_CALLBACK_COMMON(manager,
[](const_node_ptr &node) -> bool {
const auto consumers = node->get_output_target_inputs(0);
return std::all_of(consumers.begin(), consumers.end(), [](const ov::Input<ov::Node>& consumer) {
return !ov::is_type<ov::op::v0::MatMul>(consumer.get_node());
});
// Note: To tokenize MatMul with f16 const and f32 convert on weights in snippets,
// weights should be folded to f32 const before snippets, so KeepConstAndDecompression is disabled in this case
return true;
// const auto consumers = node->get_output_target_inputs(0);
// return std::all_of(consumers.begin(), consumers.end(), [](const ov::Input<ov::Node>& consumer) {
// return !ov::is_type<ov::op::v0::MatMul>(consumer.get_node());
// });
},
ov::pass::KeepConstAndDecompression);

@@ -876,6 +879,9 @@ void Transformations::PostLpt() {
}

void Transformations::MainSnippets(void) {
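// WIP: setting the REFERENCE environment variable skips snippets tokenization altogether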
if (std::getenv("REFERENCE")) {
return;
}
auto is_supported_isa = [](){
#if defined(OPENVINO_ARCH_X86_64)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2);
@@ -149,6 +149,26 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulEltwiseChainCascade, MatMulEltwise
::testing::Values(ov::test::utils::DEVICE_CPU)),
MatMul::getTestCaseName);

const size_t M = std::getenv("M") ? std::atoi(std::getenv("M")) : 32;
const size_t K = std::getenv("K") ? std::atoi(std::getenv("K")) : 64;
const size_t N = std::getenv("N") ? std::atoi(std::getenv("N")) : 256;

std::vector<std::vector<ov::test::InputShape>> fc_input_shapes{
{
{PartialShape{}, {{1, 1, M, K}}},
{{}, {{K, N}}}
},
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FullyConnected, MatMul,
::testing::Combine(
::testing::ValuesIn(fc_input_shapes),
::testing::ValuesIn(precisions(false)),
::testing::Values(MatMulType::FullyConnected),
::testing::Values(1), // MatMul
::testing::Values(1), // Tokenized MatMul
::testing::Values(ov::test::utils::DEVICE_CPU)),
MatMul::getTestCaseName);
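// Worked example (not part of the commit): with no M/K/N environment overrides, the defaults above
// (M = 32, K = 64, N = 256) give a single shape pair:
//   data:    {1, 1, 32, 64}   (M x K in the last two dimensions)
//   weights: {64, 256}        (K x N)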
const auto& transpose_b_shapes = STATIC_SHAPES(
{{3, 3, 64, 64}, {3, 3, 64, 64}},
{{1, 1, 32, 128}, {1, 1, 64, 128}},
