Skip to content

Commit

Permalink
[XLA:SPMD] Use stable sort to fix a flaky test.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678935741
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Sep 26, 2024
1 parent 0032f7c commit 03b93df
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
10 changes: 4 additions & 6 deletions xla/service/spmd/spmd_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1143,8 +1143,6 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
std::vector<int64_t>(halo_exchange_base_shape.rank(), 1)));
}

std::vector<OffsetCalculation> left_halo_size_functions(base_shape_.rank());
std::vector<OffsetCalculation> right_halo_size_functions(base_shape_.rank());
// TODO(yuanzx): We are concatenating on each sharded dimension one at time,
// and in the second dimension (and beyond) we create halos by slicing the
// concat in the previous dimension, which is not optimal. We should generate
Expand All @@ -1162,18 +1160,18 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
// partition.
MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded(
input_shard_size, explicit_left_padding[dim], 1);
left_halo_size_functions[dim] =
OffsetCalculation left_halo_size_functions =
shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];

// Right halo.
MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded(
input_shard_size, input_shard_size + explicit_left_padding[dim], 1);
right_halo_size_functions[dim] =
OffsetCalculation right_halo_size_functions =
limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;

auto resharded = ExchangeHaloAndGetValidData(
visiting_hlo, halo_exchange_base_shape, left_halo_size_functions[dim],
right_halo_size_functions[dim], explicit_left_padding[dim],
visiting_hlo, halo_exchange_base_shape, left_halo_size_functions,
right_halo_size_functions, explicit_left_padding[dim],
padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim,
*halo_exchange_target, offsets_on_padded_shape[dim], pad_value,
partition_ordinals[dim], state_.collective_ops_creator,
Expand Down
14 changes: 7 additions & 7 deletions xla/service/spmd/spmd_partitioner_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -956,8 +956,7 @@ HloInstruction* ExchangeHaloCompact(
(i + 1) * input_shard_size + right_halo_size_function.Calculate(i);
max_window_size = std::max(max_window_size, limit - start);
while (next_start < limit) {
halos[i].emplace_back();
Halo& halo = halos[i].back();
Halo& halo = halos[i].emplace_back();
halo.my_index = i;
halo.halo_offset = next_start - start;
halo.start = next_start % input_shard_size;
Expand Down Expand Up @@ -1038,11 +1037,12 @@ HloInstruction* ExchangeHaloCompact(
// Sort halos that are from the same src according to halo_offset, so that
// they are more likely to have similar characteristics.
for (int64_t i = 0; i < src_to_dst.size(); ++i) {
absl::c_sort(src_to_dst[i], [&](const std::pair<int64_t, int64_t>& a,
const std::pair<int64_t, int64_t>& b) {
return halos[a.first][a.second].halo_offset <
halos[b.first][b.second].halo_offset;
});
absl::c_stable_sort(src_to_dst[i],
[&](const std::pair<int64_t, int64_t>& a,
const std::pair<int64_t, int64_t>& b) {
return halos[a.first][a.second].halo_offset <
halos[b.first][b.second].halo_offset;
});
}

// Build collective permutes with distinct src/dst values.
Expand Down

0 comments on commit 03b93df

Please sign in to comment.