Skip to content

Commit

Permalink
[GPU] CompilationContext refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Feb 23, 2024
1 parent bdf6fce commit fa19dd4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class ICompilationContext {
using Task = std::function<void()>;
virtual void push_task(kernel_impl_params key, Task&& task) = 0;
virtual void remove_keys(std::vector<kernel_impl_params>&& keys) = 0;
virtual bool has_key(const kernel_impl_params& key) const = 0;
virtual ~ICompilationContext() = default;
virtual bool is_stopped() = 0;
virtual void cancel() = 0;
Expand Down
24 changes: 12 additions & 12 deletions src/plugins/intel_gpu/src/graph/compilation_context.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2022-2023 Intel Corporation
// Copyright (C) 2022-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -25,11 +25,10 @@ class CompilationContext : public ICompilationContext {
auto promise = std::make_shared<std::promise<void>>();

std::lock_guard<std::mutex> lock(_mutex);
futures.emplace_back(promise->get_future());

if (_task_keys.find(key) == _task_keys.end()) {
if (_task_executor != nullptr) {
_task_keys.insert(key);
_task_keys.insert({key, promise->get_future()});
_task_executor->run([task, promise] {
task();
promise->set_value();
Expand All @@ -49,6 +48,10 @@ class CompilationContext : public ICompilationContext {
}
}

bool has_key(const kernel_impl_params& key) const override {
return _task_keys.count(key) != 0;
}

~CompilationContext() noexcept {
cancel();
}
Expand All @@ -64,11 +67,7 @@ class CompilationContext : public ICompilationContext {
_stop_compilation = true;

// Flush all remaining tasks.
for (auto&& future : futures) {
if (future.valid()) {
future.wait();
}
}
wait_all();

{
std::lock_guard<std::mutex> lock(_mutex);
Expand All @@ -79,18 +78,19 @@ class CompilationContext : public ICompilationContext {
}

void wait_all() override {
for (auto&& future : futures) {
future.wait();
for (auto&& key_future : _task_keys) {
if (key_future.second.valid()) {
key_future.second.wait();
}
}
}

private:
ov::threading::IStreamsExecutor::Config _task_executor_config;
std::shared_ptr<ov::threading::IStreamsExecutor> _task_executor;
std::mutex _mutex;
std::unordered_set<kernel_impl_params, kernel_impl_params::Hasher> _task_keys;
std::unordered_map<kernel_impl_params, std::future<void>, kernel_impl_params::Hasher> _task_keys;
std::atomic_bool _stop_compilation{false};
std::vector<std::future<void>> futures;
};

std::shared_ptr<ICompilationContext> ICompilationContext::create(ov::threading::IStreamsExecutor::Config task_executor_config) {
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ bool primitive_inst::update_impl() {
if (use_async_compilation()) {
auto& compilation_context = prog->get_compilation_context();
compilation_context.push_task(updated_params_no_dyn_pad, [this, &compilation_context, updated_params_no_dyn_pad]() {
if (compilation_context.is_stopped())
if (compilation_context.is_stopped() || !compilation_context.has_key(updated_params_no_dyn_pad))
return;
auto _program = get_network().get_program();
auto& cache = _program->get_implementations_cache();
Expand Down

0 comments on commit fa19dd4

Please sign in to comment.