-
Notifications
You must be signed in to change notification settings - Fork 493
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Re-implement
ComputationClient::WaitDeviceOps
(#5811)
* Replace device_locks_ * block new ops while waiting * get lock before notifying cv * Hold onto device locks until wait is complete * formatting * shared_ptr -> unique_ptr * remove lock_device * construct counter inplace * improve logging * make operation_tracker private * formatting * Comments * absl span * address review comments * remove comment * actually make a vector for `returned_futures` * Revert "address review comments" This reverts commit 77f4b14. * naming * correct comment
- Loading branch information
1 parent
57e6035
commit 175e86b
Showing
6 changed files
with
222 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
#include "torch_xla/csrc/runtime/operation_manager.h" | ||
|
||
#include <shared_mutex> | ||
|
||
#include "absl/types/span.h" | ||
#include "torch_xla/csrc/runtime/debug_macros.h" | ||
#include "torch_xla/csrc/runtime/tf_logging.h" | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
|
||
OperationManager::OperationManager(absl::Span<const std::string> devices) { | ||
for (auto& device : devices) { | ||
op_counters_.try_emplace(device, device); | ||
} | ||
} | ||
|
||
OperationManager::OperationTracker::OperationTracker(Counter* counter) | ||
: counter_(counter) { | ||
XLA_CHECK(counter_); | ||
counter_->Increment(); | ||
} | ||
|
||
OperationManager::OperationTracker::~OperationTracker() { | ||
counter_->Decrement(); | ||
} | ||
|
||
std::unique_ptr<OperationManager::OperationTracker> | ||
OperationManager::StartOperation(std::string device) { | ||
return std::make_unique<OperationTracker>(&op_counters_.at(device)); | ||
} | ||
|
||
void OperationManager::WaitForDevices(absl::Span<const std::string> devices) { | ||
std::vector<std::unique_lock<std::shared_mutex>> locks; | ||
locks.reserve(devices.size()); | ||
|
||
for (const std::string& device_str : devices) { | ||
TF_VLOG(5) << "Blocking new operations on " << device_str; | ||
auto lock = op_counters_.at(device_str).BlockNewOperations(); | ||
locks.emplace_back(std::move(lock)); | ||
|
||
TF_VLOG(3) << "Waiting for device execution for " << device_str | ||
<< " to finish"; | ||
op_counters_.at(device_str).Wait(); | ||
TF_VLOG(3) << "Finished operations on device " << device_str; | ||
} | ||
} | ||
|
||
void OperationManager::Counter::Increment() { | ||
// Block new operations after BlockNewOperations() is called. count_ is | ||
// already atomic, so atomic so we don't need an exclusive lock to prevent | ||
// data races. | ||
std::shared_lock lock(pending_operations_mu_); | ||
auto current = count_.fetch_add(1, std::memory_order_acq_rel) + 1; | ||
TF_VLOG(5) << "Incremented operations for " << device_ << " to " << current; | ||
} | ||
|
||
void OperationManager::Counter::Decrement() { | ||
auto current = count_.fetch_sub(1, std::memory_order_acq_rel) - 1; | ||
TF_VLOG(5) << "Decremented operations for " << device_ << " to " << current; | ||
|
||
if (current == 0) { | ||
std::unique_lock cv_lock(cv_mu_); | ||
TF_VLOG(3) << "All operations complete for " << device_; | ||
cv_.notify_all(); | ||
} | ||
} | ||
|
||
std::unique_lock<std::shared_mutex> | ||
OperationManager::Counter::BlockNewOperations() { | ||
return std::unique_lock(pending_operations_mu_); | ||
} | ||
|
||
void OperationManager::Counter::Wait() { | ||
TF_VLOG(5) << "Waiting for " << count_ << " operations on " << device_; | ||
std::unique_lock cv_lock(cv_mu_); | ||
cv_.wait(cv_lock, | ||
[this] { return count_.load(std::memory_order_acquire) == 0; }); | ||
TF_VLOG(5) << "Done waiting for " << device_; | ||
} | ||
|
||
} // namespace runtime | ||
} // namespace torch_xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#ifndef XLA_CLIENT_OPERATION_MANAGER_H_ | ||
#define XLA_CLIENT_OPERATION_MANAGER_H_ | ||
|
||
#include <atomic> | ||
#include <condition_variable> | ||
#include <memory> | ||
#include <mutex> | ||
#include <shared_mutex> | ||
|
||
#include "absl/types/span.h" | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
|
||
// Track inflight operations for each device. | ||
class OperationManager { | ||
public: | ||
OperationManager() = default; | ||
OperationManager(absl::Span<const std::string>); | ||
|
||
OperationManager(const OperationManager&) = delete; | ||
OperationManager& operator=(const OperationManager&) = delete; | ||
|
||
OperationManager(OperationManager&&) = default; | ||
OperationManager& operator=(OperationManager&&) = default; | ||
|
||
class Counter { | ||
public: | ||
Counter(const std::string& device) : device_(device){}; | ||
|
||
Counter(const Counter&) = delete; | ||
Counter& operator=(const Counter&) = delete; | ||
|
||
// Register a new operation. Blocks if `BlockNewOperations` has been called. | ||
void Increment(); | ||
|
||
// Mark an inflight task completed. | ||
void Decrement(); | ||
|
||
// Wait until all operations are complete. Does not block new operations | ||
// (see BlockNewOperations). | ||
void Wait(); | ||
|
||
// Returns a lock that prevents new operations on the device. | ||
std::unique_lock<std::shared_mutex> BlockNewOperations(); | ||
|
||
private: | ||
std::string device_; | ||
|
||
std::shared_mutex pending_operations_mu_; | ||
std::atomic<int64_t> count_{0}; | ||
|
||
std::mutex cv_mu_; | ||
std::condition_variable cv_; | ||
}; | ||
|
||
class OperationTracker { | ||
public: | ||
// Register an operation in the `counter_`. | ||
OperationTracker(Counter* counter); | ||
|
||
// Mark an operation complete in `counter_`. | ||
~OperationTracker(); | ||
|
||
OperationTracker(const OperationTracker&) = delete; | ||
OperationTracker& operator=(const OperationTracker&) = delete; | ||
|
||
private: | ||
std::string device_; | ||
Counter* counter_; | ||
}; | ||
|
||
// Register a new operation for `device`. | ||
std::unique_ptr<OperationTracker> StartOperation(std::string device); | ||
|
||
// Wait for all device execution to complete on devices. | ||
void WaitForDevices(absl::Span<const std::string> devices); | ||
|
||
private: | ||
std::unordered_map<std::string, Counter> op_counters_; | ||
}; | ||
|
||
} // namespace runtime | ||
} // namespace torch_xla | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.