Skip to content

Commit

Permalink
SAM experiments
Browse files (browse the repository at this point in the history)
  • Loading branch information
v-Golubev committed Oct 6, 2023
1 parent e434466 commit 0a06fa3
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
};

bool modified = false;
size_t count = 0;
for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
const auto& brgemm_expr = *expr_it;
const auto brgemm = ov::as_type_ptr<ov::intel_cpu::BrgemmCPU>(brgemm_expr->get_node());
Expand Down Expand Up @@ -168,50 +169,60 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
}
};

try {
if (auto order = std::getenv("ORDER")) {
if (std::string(order) == "mnk") {
apply_m_blocking();
apply_n_blocking();
apply_k_blocking();
std::cout << "[ INFO ] Blocking order: mnk\n";
} else if (std::string(order) == "mkn") {
apply_m_blocking();
apply_k_blocking();
apply_n_blocking();
std::cout << "[ INFO ] Blocking order: mkn\n";
} else if (std::string(order) == "nmk") {
apply_n_blocking();
apply_m_blocking();
apply_k_blocking();
std::cout << "[ INFO ] Blocking order: nmk\n";
} else if (std::string(order) == "nkm") {
apply_n_blocking();
apply_k_blocking();
apply_m_blocking();
std::cout << "[ INFO ] Blocking order: nkm\n";
} else if (std::string(order) == "kmn") {
apply_k_blocking();
apply_m_blocking();
apply_n_blocking();
std::cout << "[ INFO ] Blocking order: kmn\n";
} else if (std::string(order) == "knm") {
apply_k_blocking();
apply_n_blocking();
apply_m_blocking();
std::cout << "[ INFO ] Blocking order: knm\n";
auto apply_order = [&](const char* env_name) {
try {
std::cout << "[ INFO ] Blocking " << env_name << ": ";
if (auto order = std::getenv(env_name)) {
if (std::string(order) == "mnk") {
apply_m_blocking();
apply_n_blocking();
apply_k_blocking();
std::cout << "mnk\n";
} else if (std::string(order) == "mkn") {
apply_m_blocking();
apply_k_blocking();
apply_n_blocking();
std::cout << "mkn\n";
} else if (std::string(order) == "nmk") {
apply_n_blocking();
apply_m_blocking();
apply_k_blocking();
std::cout << "nmk\n";
} else if (std::string(order) == "nkm") {
apply_n_blocking();
apply_k_blocking();
apply_m_blocking();
std::cout << "nkm\n";
} else if (std::string(order) == "kmn") {
apply_k_blocking();
apply_m_blocking();
apply_n_blocking();
std::cout << "kmn\n";
} else if (std::string(order) == "knm") {
apply_k_blocking();
apply_n_blocking();
apply_m_blocking();
std::cout << "knm\n";
} else {
throw "wrong blocking order";
}
} else {
throw "wrong blocking order";
}
} else {
throw "wrong blocking order";
} catch (...) {
std::cout << "fallback - knm is chosen\n";
apply_k_blocking();
apply_n_blocking();
apply_m_blocking();
}
} catch(...) {
std::cout << "[ WARNING ] Blocking order fallback: knm is chosen\n";
apply_k_blocking();
apply_n_blocking();
apply_m_blocking();
};

if (count == 0) {
apply_order("ORDER0");
} else {
apply_order("ORDER1");
}
count++;

brgemm_expr->get_input_port_descriptor(0)->set_subtensor(input_0_subtensor);
brgemm_expr->get_input_port_descriptor(1)->set_subtensor(input_1_subtensor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,36 +72,47 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
size_t brgemm_block_size_m = 0;
size_t brgemm_block_size_k = 0;
size_t brgemm_block_size_n = 0;
try {
if (auto m = std::getenv("M")) {
brgemm_block_size_m = std::atoi(m);
auto apply_blocking_parameters = [&](const std::shared_ptr<ov::Node>& node) {
auto name = node->get_friendly_name();
std::cout << name << std::endl;
char* m; char* k; char* n;
if (name.find("MatMul_1") != std::string::npos) {
m = "M1"; k = "K1"; n = "N1";
} else {
m = "M0"; k = "K0"; n = "N0";
}
if (auto k = std::getenv("K")) {
brgemm_block_size_k = std::atoi(k);
try {
if (auto m_block = std::getenv(m)) {
brgemm_block_size_m = std::atoi(m_block);
}
if (auto k_block = std::getenv(k)) {
brgemm_block_size_k = std::atoi(k_block);
}
if (auto n_block = std::getenv(n)) {
brgemm_block_size_n = std::atoi(n_block);
}
if (brgemm_block_size_m == 0 || brgemm_block_size_k == 0 || brgemm_block_size_n == 0) {
throw "incorrect blocking params";
}
std::cout << "[ INFO ] Blocking: env variables\n";
} catch (...) {
std::cout << "[ INFO ] Blocking: fallback\n";
brgemm_block_size_m = get_block_size_m(M);
brgemm_block_size_k = get_block_size_k(K);
brgemm_block_size_n = get_block_size_n(N);
}
if (auto n = std::getenv("N")) {
brgemm_block_size_n = std::atoi(n);
}
if (brgemm_block_size_m == 0 || brgemm_block_size_k == 0 || brgemm_block_size_n == 0) {
throw "incorrect blocking params";
}
std::cout << "[ INFO ] Blocking: env variables\n";
} catch (...) {
std::cout << "[ INFO ] Blocking: fallback\n";
brgemm_block_size_m = get_block_size_m(M);
brgemm_block_size_k = get_block_size_k(K);
brgemm_block_size_n = get_block_size_n(N);
}

if (input_1_precision != ov::element::f32) {
std::cout << "[ WARNING ] non f32 precision: K & N blocking params are ignored\n";
brgemm_block_size_k = K;
brgemm_block_size_n = N;
}
if (input_1_precision != ov::element::f32) {
std::cout << "[ WARNING ] non f32 precision: K & N blocking params are ignored\n";
brgemm_block_size_k = K;
brgemm_block_size_n = N;
}

std::cout << "\tM = " << brgemm_block_size_m << "\n";
std::cout << "\tK = " << brgemm_block_size_k << "\n";
std::cout << "\tN = " << brgemm_block_size_n << "\n";
std::cout << "\t" << m << " = " << brgemm_block_size_m << "\n";
std::cout << "\t" << k << " = " << brgemm_block_size_k << "\n";
std::cout << "\t" << n << " = " << brgemm_block_size_n << "\n";
};
apply_blocking_parameters(brgemm);

brgemm->set_m_block_size(brgemm_block_size_m);
brgemm->set_k_block_size(brgemm_block_size_k);
Expand Down

0 comments on commit 0a06fa3

Please sign in to comment.