Skip to content

Commit

Permalink
[DEBUG]: BF16 validation
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Apr 9, 2023
1 parent 6aff821 commit c1b16bf
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,15 @@ std::map<std::string, std::shared_ptr<const T>> const_map_cast(const std::map<st

std::shared_ptr<IExecutableNetworkInternal> IInferencePlugin::LoadNetwork(
const CNNNetwork& orig_network,
const std::map<std::string, std::string>& config,
const std::map<std::string, std::string>& config_,
const std::shared_ptr<RemoteContext>& context) {
auto config = config_;

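// TODO: debug-only test code - adds a bf16 INFERENCE_PRECISION_HINT unless the key is already present (emplace does not overwrite)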
ov::element::Type hint = ov::element::bf16;
config.emplace(std::string("INFERENCE_PRECISION_HINT"), hint.get_type_name());
std::cout << "IInferencePlugin::LoadNetwork: INFERENCE_PRECISION_HINT: " << hint.get_type_name() << std::endl;

std::shared_ptr<IExecutableNetworkInternal> impl;

// if IR `version` is not set, suppose it's IR v10 for old API
Expand Down
10 changes: 9 additions & 1 deletion src/plugins/intel_cpu/src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,17 @@ void Config::applyDebugCapsProperties() {
}
#endif

void Config::readProperties(const std::map<std::string, std::string> &prop) {
void Config::readProperties(const std::map<std::string, std::string> &prop_) {
const auto streamExecutorConfigKeys = streamExecutorConfig.SupportedKeys();
const auto hintsConfigKeys = perfHintsConfig.SupportedKeys();

auto prop = prop_;

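// TODO: debug-only test code - adds a bf16 INFERENCE_PRECISION_HINT unless the key is already present (emplace does not overwrite)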
ov::element::Type hint = ov::element::bf16;
prop.emplace(std::string("INFERENCE_PRECISION_HINT"), hint.get_type_name());
std::cout << "Config::readProperties: INFERENCE_PRECISION_HINT: " << hint.get_type_name() << std::endl;

for (const auto& kvp : prop) {
const auto& key = kvp.first;
const auto& val = kvp.second;
Expand Down
13 changes: 13 additions & 0 deletions src/plugins/intel_cpu/src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@ void Graph::Replicate(const std::shared_ptr<const ov::Model> &subgraph) {
graphNodes.push_back(outNode);
}

std::cout << "Graph::Replicate" << std::endl;
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16)) {
std::cout << "Graph::Replicate: BF16 is supported" << std::endl;
} else {
std::cout << "Graph::Replicate: BF16 is not supported" << std::endl;
}

if (getConfig().enforceBF16)
EnforceBF16();
}
Expand Down Expand Up @@ -308,6 +315,12 @@ void Graph::Replicate(const CNNNetwork &network) {
graphNodes.push_back(outNode);
}

if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16)) {
std::cout << "Graph::Replicate: BF16 is supported" << std::endl;
} else {
std::cout << "Graph::Replicate: BF16 is not supported" << std::endl;
}

if (getConfig().enforceBF16)
EnforceBF16();

Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) {
// Note: MatMul decomposition will be run again later in case BF16 enforcement does not happen
pre_dialect.register_pass<ngraph::snippets::pass::MatMulToBrgemm>();
pre_dialect.register_pass<pass::EnforcePrecision>(element::f32, element::bf16);
std::cout << "Snippet::generate: EnforcePrecision was added in pipeline" << std::endl;
}

ov::pass::Manager post_dialect;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,11 @@ bool EnforcePrecision::run_on_model(const std::shared_ptr<ov::Model>& f) {
// remove convert
op_parent->output(0).replace(op_parent->get_input_source_output(0));
was_updated = true;
std::cout << "EnforcePrecision::run_on_model: convertion removal: " << op->get_type_name() << ":" << op->get_friendly_name() << std::endl;
} else if (supported_precisions_to_enforce[index] != actual_precisions[index]) {
insert_convert(op->get_input_source_output(index), op, index, target);
was_updated = true;
std::cout << "EnforcePrecision::run_on_model: convertion insertion: " << op->get_type_name() << ":" << op->get_friendly_name() << std::endl;
}
}
}
Expand Down

0 comments on commit c1b16bf

Please sign in to comment.