Skip to content

Commit

Permalink
Re-implement ComputationClient::WaitDeviceOps (#5811)
Browse files Browse the repository at this point in the history
* Replace device_locks_

* block new ops while waiting

* get lock before notifying cv

* Hold onto device locks until wait is complete

* formatting

* shared_ptr -> unique_ptr

* remove lock_device

* construct counter inplace

* improve logging

* make operation_tracker private

* formatting

* Comments

* absl span

* address review comments

* remove comment

* actually make a vector for `returned_futures`

* Revert "address review comments"

This reverts commit 77f4b14.

* naming

* correct comment
  • Loading branch information
will-cromar authored Nov 21, 2023
1 parent 57e6035 commit 175e86b
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 90 deletions.
13 changes: 13 additions & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ cc_library(
":computation_client",
":debug_macros",
":env_vars",
":operation_manager",
":profiler",
":stablehlo_helper",
":tensor_source",
Expand Down Expand Up @@ -187,6 +188,18 @@ cc_library(
],
)

cc_library(
name = "operation_manager",
srcs = ["operation_manager.cc"],
hdrs = ["operation_manager.h"],
visibility = ["//visibility:private"],
deps = [
":debug_macros",
":tf_logging",
"@com_google_absl//absl/types:span",
],
)

# Profiler silently fails unless we link these backends
cc_library(
name = "profiler_backends",
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ class ComputationClient {

// Block until pass in devices' async operation are finished. If empty, all
// the local devices will be waited for.
virtual void WaitDeviceOps(const std::vector<std::string>& devices) = 0;
virtual void WaitDeviceOps(absl::Span<const std::string> devices) = 0;

// Check whether the XlaCoordinator has been initialized.
virtual bool CoordinatorInitialized() const = 0;
Expand Down
83 changes: 83 additions & 0 deletions torch_xla/csrc/runtime/operation_manager.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#include "torch_xla/csrc/runtime/operation_manager.h"

#include <shared_mutex>

#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/tf_logging.h"

namespace torch_xla {
namespace runtime {

OperationManager::OperationManager(absl::Span<const std::string> devices) {
for (auto& device : devices) {
op_counters_.try_emplace(device, device);
}
}

OperationManager::OperationTracker::OperationTracker(Counter* counter)
: counter_(counter) {
XLA_CHECK(counter_);
counter_->Increment();
}

OperationManager::OperationTracker::~OperationTracker() {
counter_->Decrement();
}

std::unique_ptr<OperationManager::OperationTracker>
OperationManager::StartOperation(std::string device) {
return std::make_unique<OperationTracker>(&op_counters_.at(device));
}

void OperationManager::WaitForDevices(absl::Span<const std::string> devices) {
std::vector<std::unique_lock<std::shared_mutex>> locks;
locks.reserve(devices.size());

for (const std::string& device_str : devices) {
TF_VLOG(5) << "Blocking new operations on " << device_str;
auto lock = op_counters_.at(device_str).BlockNewOperations();
locks.emplace_back(std::move(lock));

TF_VLOG(3) << "Waiting for device execution for " << device_str
<< " to finish";
op_counters_.at(device_str).Wait();
TF_VLOG(3) << "Finished operations on device " << device_str;
}
}

void OperationManager::Counter::Increment() {
// Block new operations after BlockNewOperations() is called. count_ is
// already atomic, so atomic so we don't need an exclusive lock to prevent
// data races.
std::shared_lock lock(pending_operations_mu_);
auto current = count_.fetch_add(1, std::memory_order_acq_rel) + 1;
TF_VLOG(5) << "Incremented operations for " << device_ << " to " << current;
}

void OperationManager::Counter::Decrement() {
auto current = count_.fetch_sub(1, std::memory_order_acq_rel) - 1;
TF_VLOG(5) << "Decremented operations for " << device_ << " to " << current;

if (current == 0) {
std::unique_lock cv_lock(cv_mu_);
TF_VLOG(3) << "All operations complete for " << device_;
cv_.notify_all();
}
}

std::unique_lock<std::shared_mutex>
OperationManager::Counter::BlockNewOperations() {
return std::unique_lock(pending_operations_mu_);
}

void OperationManager::Counter::Wait() {
TF_VLOG(5) << "Waiting for " << count_ << " operations on " << device_;
std::unique_lock cv_lock(cv_mu_);
cv_.wait(cv_lock,
[this] { return count_.load(std::memory_order_acquire) == 0; });
TF_VLOG(5) << "Done waiting for " << device_;
}

} // namespace runtime
} // namespace torch_xla
86 changes: 86 additions & 0 deletions torch_xla/csrc/runtime/operation_manager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#ifndef XLA_CLIENT_OPERATION_MANAGER_H_
#define XLA_CLIENT_OPERATION_MANAGER_H_

#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <shared_mutex>

#include "absl/types/span.h"

namespace torch_xla {
namespace runtime {

// Track inflight operations for each device.
class OperationManager {
public:
OperationManager() = default;
OperationManager(absl::Span<const std::string>);

OperationManager(const OperationManager&) = delete;
OperationManager& operator=(const OperationManager&) = delete;

OperationManager(OperationManager&&) = default;
OperationManager& operator=(OperationManager&&) = default;

class Counter {
public:
Counter(const std::string& device) : device_(device){};

Counter(const Counter&) = delete;
Counter& operator=(const Counter&) = delete;

// Register a new operation. Blocks if `BlockNewOperations` has been called.
void Increment();

// Mark an inflight task completed.
void Decrement();

// Wait until all operations are complete. Does not block new operations
// (see BlockNewOperations).
void Wait();

// Returns a lock that prevents new operations on the device.
std::unique_lock<std::shared_mutex> BlockNewOperations();

private:
std::string device_;

std::shared_mutex pending_operations_mu_;
std::atomic<int64_t> count_{0};

std::mutex cv_mu_;
std::condition_variable cv_;
};

class OperationTracker {
public:
// Register an operation in the `counter_`.
OperationTracker(Counter* counter);

// Mark an operation complete in `counter_`.
~OperationTracker();

OperationTracker(const OperationTracker&) = delete;
OperationTracker& operator=(const OperationTracker&) = delete;

private:
std::string device_;
Counter* counter_;
};

// Register a new operation for `device`.
std::unique_ptr<OperationTracker> StartOperation(std::string device);

// Wait for all device execution to complete on devices.
void WaitForDevices(absl::Span<const std::string> devices);

private:
std::unordered_map<std::string, Counter> op_counters_;
};

} // namespace runtime
} // namespace torch_xla

#endif
119 changes: 36 additions & 83 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/operation_manager.h"
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
Expand Down Expand Up @@ -206,10 +207,11 @@ PjRtComputationClient::PjRtComputationClient() {
global_ordinals_[device->id()] = global_ordinals_.size();
std::string device_str = PjRtDeviceToString(device);
string_to_device_.emplace(device_str, device);
device_locks_.emplace(device_str, std::make_unique<std::shared_mutex>());
}
// manually create the device_locks for SPMD device
device_locks_.emplace(spmd_device_str, std::make_unique<std::shared_mutex>());

auto tracked_devices = GetLocalDevices();
tracked_devices.emplace_back(spmd_device_str);
operation_manager_ = std::move(OperationManager(std::move(tracked_devices)));
}

PjRtComputationClient::~PjRtComputationClient() {
Expand Down Expand Up @@ -601,13 +603,24 @@ PjRtComputationClient::ExecuteComputation(
// Required as of cl/518733871
execute_options.use_major_to_minor_data_layout_for_callbacks = true;

TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device;
auto op_tracker = operation_manager_.StartOperation(device);
TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device
<< " Done";

std::optional<xla::PjRtFuture<xla::Status>> returned_future;
std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
pjrt_computation.executable
->ExecuteSharded(buffers, pjrt_device, execute_options,
returned_future)
.value();

returned_future->OnReady(std::move(
[timed, op_tracker = std::move(op_tracker)](xla::Status unused) mutable {
timed.reset();
TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished";
}));

std::vector<DataPtr> datas;
datas.reserve(results.size());
for (auto& result : results) {
Expand All @@ -620,31 +633,6 @@ PjRtComputationClient::ExecuteComputation(
}
CreateDataHandlesCounter()->AddValue(datas.size());

thread::Schedule(std::move([&, this, device,
returned_future = std::move(*returned_future),
timed]() mutable {
TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for "
<< device;
// Grab the shared lock and block the `WaitDeviceOps` until buffer is
// ready.
// TODO(JackCaoG): This lock should acquired outside of the lockfn and
// passed in. It is possible that lockfn started after ExecuteComputation
// released the xla_graph_executor lock, which will create a short windows
// where device is unlcoked while execution is still running.
auto lock = lock_device_shared(device);
TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device
<< " Done";
// Signal that `ExecuteSharded` has completed for the ExecuteTime
// metric. Copies the `timed` shared pointer into the lambda.
XLA_CHECK(returned_future.IsValid())
<< "returned_future in ExecuteComputation is empty";
returned_future.OnReady(
[timed, lock = std::move(lock)](xla::Status unused) mutable {
timed.reset();
TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished";
});
}));

TF_VLOG(1) << "Returning " << datas.size() << " results";
return datas;
}
Expand Down Expand Up @@ -704,6 +692,15 @@ PjRtComputationClient::ExecuteReplicated(
// Required as of cl/518733871
execute_options.use_major_to_minor_data_layout_for_callbacks = true;

// Grab the shared lock and block the `WaitDeviceOps` until buffer is
// ready. Since this is the SPMD code path. There is no points to grab
// devices lock for every individual device.
TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for "
<< spmd_device_str;
auto op_tracker = operation_manager_.StartOperation(spmd_device_str);
TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for "
<< spmd_device_str << " Done";

std::optional<std::vector<xla::PjRtFuture<xla::Status>>> returned_futures(
devices.size());
std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results;
Expand All @@ -715,6 +712,13 @@ PjRtComputationClient::ExecuteReplicated(
->Execute(std::move(argument_handles), execute_options,
returned_futures)
.value();

(*returned_futures)[0].OnReady(
std::move([timed, op_tracker = std::move(op_tracker)](
xla::Status unused) mutable {
timed.reset();
TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished";
}));
}

std::vector<std::vector<ComputationClient::DataPtr>> data_handles;
Expand Down Expand Up @@ -747,31 +751,6 @@ PjRtComputationClient::ExecuteReplicated(
}
}

thread::Schedule(std::move([&, this,
returned_futures = std::move(*returned_futures),
timed]() mutable {
// Grab the shared lock and block the `WaitDeviceOps` until buffer is
// ready. Since this is the SPMD code path. There is no points to grab
// devices lock for every individual device.
TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for "
<< spmd_device_str;
auto lock = lock_device_shared(spmd_device_str);
TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for "
<< spmd_device_str << " Done";
// Signal that `ExecuteReplicated` has completed for one of the devices
// the ExecuteReplicatedTime metric. Here, we assume that all devices
// will finish execution roughly at the same time, hence only use one of
// the returned_futures. Copies the `timed` shared pointer into the
// lambda.
XLA_CHECK(returned_futures[0].IsValid())
<< "returned_future in ExecuteReplicated is empty";
returned_futures[0].OnReady(
[timed, lock = std::move(lock)](xla::Status unused) mutable {
timed.reset();
TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished";
});
}));

TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results "
<< "with dimensions [" << absl::StrJoin(dims, ",") << "].";
return data_handles;
Expand Down Expand Up @@ -826,37 +805,11 @@ xla::PjRtDevice* PjRtComputationClient::StringToPjRtDevice(
return pjrt_device;
}

std::shared_lock<std::shared_mutex> PjRtComputationClient::lock_device_shared(
const std::string& device) {
std::shared_lock lock(*device_locks_[device]);
return lock;
}

std::unique_lock<std::shared_mutex> PjRtComputationClient::lock_device(
const std::string& device) {
std::unique_lock lock(*device_locks_[device]);
return lock;
}

void PjRtComputationClient::WaitDeviceOps(
const std::vector<std::string>& devices) {
std::unordered_set<std::string> wait_devices;
if (!devices.empty()) {
for (auto& device_str : devices) {
wait_devices.insert(device_str);
}
} else {
for (auto& device_str : GetLocalDevices()) {
wait_devices.insert(device_str);
}
}
for (const std::string& device_str : wait_devices) {
TF_VLOG(3) << "Waiting for device execution for " << device_str
<< " to finish";
lock_device(device_str);
TF_VLOG(3) << "Waiting for device execution for " << device_str
<< " to finish.. Done";
}
absl::Span<const std::string> devices) {
TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", ");
operation_manager_.WaitForDevices(devices.empty() ? GetLocalDevices()
: devices);
}

std::map<std::string, Metric> PjRtComputationClient::GetMetrics() const {
Expand Down
Loading

0 comments on commit 175e86b

Please sign in to comment.