Skip to content

Commit

Permalink
Remove barrier test in pjrt/distributed with equivalent test coverage…
Browse files Browse the repository at this point in the history
… in coordination service's fork of client_server_test.

1. This is in preparation for a change in barrier semantics (don't require users to specify unique ids).

2. Moving forward, we want to shift new business logic tests to be in coord service's test suite.

This allows us to cover more edge cases with intrusive hooks (e.g. agent dtor tests) as well as non-Jax (i.e. TF) scenarios.

3. Only use pjrt/distributed tests + Jax multi-process Python tests for Jax-specific requirements (e.g. topology exchange contract), or to exercise the xla_client (nanobind) and pjrt/distributed codepaths to validate that args are plumbed correctly.

This reduces review burden on the Jax team.

PiperOrigin-RevId: 693752761
  • Loading branch information
Google-ML-Automation committed Nov 8, 2024
1 parent 727dedc commit f83c9f0
Showing 1 changed file with 0 additions and 31 deletions.
31 changes: 0 additions & 31 deletions xla/pjrt/distributed/client_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -899,37 +899,6 @@ TEST_F(ClientServerTest, WaitAtBarrier_TimeoutWithDifferentBarrierId) {
}
}

TEST_F(ClientServerTest, WaitAtBarrier_FailWithSameBarrierId) {
int num_nodes = 2;
StartService(num_nodes);

auto thread_fn = [&](int node_id) -> absl::Status {
auto client = GetClient(node_id);
TF_RETURN_IF_ERROR(client->Connect());

TF_RETURN_IF_ERROR(
client->WaitAtBarrier("barrier_1", kBarrierTimeout, std::nullopt));
TF_RETURN_IF_ERROR(
client->WaitAtBarrier("barrier_1", kBarrierTimeout, std::nullopt));

TF_RETURN_IF_ERROR(client->Shutdown());
return absl::OkStatus();
};

std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
}
}
for (int i = 0; i < num_nodes; ++i) {
EXPECT_EQ(statuses[i].code(), tsl::error::FAILED_PRECONDITION)
<< " node id: " << i;
}
}

TEST_F(ClientServerTest, WaitAtBarrierSubset_Succeeds) {
int num_nodes = 3;
StartService(num_nodes);
Expand Down

0 comments on commit f83c9f0

Please sign in to comment.