From 0a06fa367027f6a41eaada8919c70b217bb13d15 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Fri, 6 Oct 2023 11:56:26 +0200 Subject: [PATCH] SAM experiments --- .../x64/pass/lowered/brgemm_blocking.cpp | 89 +++++++++++-------- .../pass/set_brgemm_cpu_blocking_params.cpp | 63 +++++++------ 2 files changed, 87 insertions(+), 65 deletions(-) diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp index 65ca24cdaa0d7d..f1e1428ea8ff17 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp @@ -46,6 +46,7 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) { }; bool modified = false; + size_t count = 0; for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) { const auto& brgemm_expr = *expr_it; const auto brgemm = ov::as_type_ptr(brgemm_expr->get_node()); @@ -168,50 +169,60 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) { } }; - try { - if (auto order = std::getenv("ORDER")) { - if (std::string(order) == "mnk") { - apply_m_blocking(); - apply_n_blocking(); - apply_k_blocking(); - std::cout << "[ INFO ] Blocking order: mnk\n"; - } else if (std::string(order) == "mkn") { - apply_m_blocking(); - apply_k_blocking(); - apply_n_blocking(); - std::cout << "[ INFO ] Blocking order: mkn\n"; - } else if (std::string(order) == "nmk") { - apply_n_blocking(); - apply_m_blocking(); - apply_k_blocking(); - std::cout << "[ INFO ] Blocking order: nmk\n"; - } else if (std::string(order) == "nkm") { - apply_n_blocking(); - apply_k_blocking(); - apply_m_blocking(); - std::cout << "[ INFO ] Blocking order: nkm\n"; - } else if (std::string(order) == "kmn") { - apply_k_blocking(); - apply_m_blocking(); - apply_n_blocking(); - std::cout << "[ INFO ] Blocking order: kmn\n"; - } else if (std::string(order) == "knm") { - apply_k_blocking(); - apply_n_blocking(); - apply_m_blocking(); - std::cout << "[ INFO ] Blocking order: knm\n"; + auto apply_order = [&](const char* env_name) { + try { + std::cout << "[ INFO ] Blocking " << env_name << ": "; + if (auto order = std::getenv(env_name)) { + if (std::string(order) == "mnk") { + apply_m_blocking(); + apply_n_blocking(); + apply_k_blocking(); + std::cout << "mnk\n"; + } else if (std::string(order) == "mkn") { + apply_m_blocking(); + apply_k_blocking(); + apply_n_blocking(); + std::cout << "mkn\n"; + } else if (std::string(order) == "nmk") { + apply_n_blocking(); + apply_m_blocking(); + apply_k_blocking(); + std::cout << "nmk\n"; + } else if (std::string(order) == "nkm") { + apply_n_blocking(); + apply_k_blocking(); + apply_m_blocking(); + std::cout << "nkm\n"; + } else if (std::string(order) == "kmn") { + apply_k_blocking(); + apply_m_blocking(); + apply_n_blocking(); + std::cout << "kmn\n"; + } else if (std::string(order) == "knm") { + apply_k_blocking(); + apply_n_blocking(); + apply_m_blocking(); + std::cout << "knm\n"; + } else { + throw "wrong blocking order"; + } } else { throw "wrong blocking order"; } - } else { - throw "wrong blocking order"; + } catch (...) { + std::cout << "fallback - knm is chosen\n"; + apply_k_blocking(); + apply_n_blocking(); + apply_m_blocking(); } - } catch(...) { - std::cout << "[ WARNING ] Blocking order fallback: knm is chosen\n"; - apply_k_blocking(); - apply_n_blocking(); - apply_m_blocking(); + }; + + if (count == 0) { + apply_order("ORDER0"); + } else { + apply_order("ORDER1"); } + count++; brgemm_expr->get_input_port_descriptor(0)->set_subtensor(input_0_subtensor); brgemm_expr->get_input_port_descriptor(1)->set_subtensor(input_1_subtensor); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp index f1f5f4850b34ab..22b871765b623c 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp @@ -72,36 +72,47 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() { size_t brgemm_block_size_m = 0; size_t brgemm_block_size_k = 0; size_t brgemm_block_size_n = 0; - try { - if (auto m = std::getenv("M")) { - brgemm_block_size_m = std::atoi(m); + auto apply_blocking_parameters = [&](const std::shared_ptr& node) { + auto name = node->get_friendly_name(); + std::cout << name << std::endl; + char* m; char* k; char* n; + if (name.find("MatMul_1") != std::string::npos) { + m = "M1"; k = "K1"; n = "N1"; + } else { + m = "M0"; k = "K0"; n = "N0"; } - if (auto k = std::getenv("K")) { - brgemm_block_size_k = std::atoi(k); + try { + if (auto m_block = std::getenv(m)) { + brgemm_block_size_m = std::atoi(m_block); + } + if (auto k_block = std::getenv(k)) { + brgemm_block_size_k = std::atoi(k_block); + } + if (auto n_block = std::getenv(n)) { + brgemm_block_size_n = std::atoi(n_block); + } + if (brgemm_block_size_m == 0 || brgemm_block_size_k == 0 || brgemm_block_size_n == 0) { + throw "incorrect blocking params"; + } + std::cout << "[ INFO ] Blocking: env variables\n"; + } catch (...) { + std::cout << "[ INFO ] Blocking: fallback\n"; + brgemm_block_size_m = get_block_size_m(M); + brgemm_block_size_k = get_block_size_k(K); + brgemm_block_size_n = get_block_size_n(N); } - if (auto n = std::getenv("N")) { - brgemm_block_size_n = std::atoi(n); - } - if (brgemm_block_size_m == 0 || brgemm_block_size_k == 0 || brgemm_block_size_n == 0) { - throw "incorrect blocking params"; - } - std::cout << "[ INFO ] Blocking: env variables\n"; - } catch (...) { - std::cout << "[ INFO ] Blocking: fallback\n"; - brgemm_block_size_m = get_block_size_m(M); - brgemm_block_size_k = get_block_size_k(K); - brgemm_block_size_n = get_block_size_n(N); - } - if (input_1_precision != ov::element::f32) { - std::cout << "[ WARNING ] non f32 precision: K & N blocking params are ignored\n"; - brgemm_block_size_k = K; - brgemm_block_size_n = N; - } + if (input_1_precision != ov::element::f32) { + std::cout << "[ WARNING ] non f32 precision: K & N blocking params are ignored\n"; + brgemm_block_size_k = K; + brgemm_block_size_n = N; + } - std::cout << "\tM = " << brgemm_block_size_m << "\n"; - std::cout << "\tK = " << brgemm_block_size_k << "\n"; - std::cout << "\tN = " << brgemm_block_size_n << "\n"; + std::cout << "\t" << m << " = " << brgemm_block_size_m << "\n"; + std::cout << "\t" << k << " = " << brgemm_block_size_k << "\n"; + std::cout << "\t" << n << " = " << brgemm_block_size_n << "\n"; + }; + apply_blocking_parameters(brgemm); brgemm->set_m_block_size(brgemm_block_size_m); brgemm->set_k_block_size(brgemm_block_size_k);