Skip to content

Commit

Permalink
Relax verification of ShardingParam so that a dimension can be sharde…
Browse files Browse the repository at this point in the history
…d over multiple axis.

PiperOrigin-RevId: 617356700
  • Loading branch information
ICGog authored and copybara-github committed Mar 20, 2024
1 parent b8de936 commit eac67fd
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 14 deletions.
9 changes: 3 additions & 6 deletions xla/python/ifrt/ir/sharding_param.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,9 @@ absl::Status ShardingParam::verify() const {
break;
}
cum_size *= minor_to_major().axis_sizes[index];
if (cum_size > dim_shards()[dim_index]) {
return absl::InvalidArgumentError(absl::StrCat(
"Dimension #", dim_index, " of ", dim_shards()[dim_index],
" shards can't be assigned to the axes"));
} else if (cum_size == dim_shards()[dim_index]) {
cum_size = 1;
while (dim_index < dim_shards().size() &&
cum_size % dim_shards()[dim_index] == 0) {
cum_size /= dim_shards()[dim_index];
dim_index++;
}
}
Expand Down
2 changes: 1 addition & 1 deletion xla/python/ifrt/ir/tests/verify_array.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func.func @array_requires_enough_devices() {
// -----

func.func @array_requires_shard_distributable_to_axes() {
// expected-error@+2 {{Dimension #1 of 2 shards can't be assigned to the axes}}
// expected-error@+2 {{Can't shard the dims 1x2 to the mesh of [0] on 3}}
%0 = builtin.unrealized_conversion_cast to
!ifrt.array<tensor<4x4xi32>, 1x2 to [0] on 3, [0,1,2]>
return
Expand Down
8 changes: 1 addition & 7 deletions xla/python/ifrt/support/sharding_conversions_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,6 @@ TEST(ShardingConversionsTest, VerifyIncorrectShardings) {
ShardingParam too_many_slices{/*dim_shards=*/{2, 2},
{/*permutation=*/{0}, /*axis_sizes=*/{2}}};
EXPECT_FALSE(too_many_slices.verify().ok());
ShardingParam cannot_distribute_slices{
/*dim_shards=*/{1, 2}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{3, 2}}};
EXPECT_FALSE(cannot_distribute_slices.verify().ok());
ShardingParam incorrect_permutation{
/*dim_shards=*/{4, 1},
{/*permutation=*/{0, 1, 1}, /*axis_sizes=*/{2, 2, 2}}};
Expand Down Expand Up @@ -197,10 +194,7 @@ TEST_P(HloShardingToShardingParamTest, HloShardingToShardingParam) {
TF_ASSERT_OK_AND_ASSIGN(
auto sharding_param,
ToShardingParam(param.hlo_sharding, param.rank, param.num_devices));
// We cannot verify sharding param because we're losing info about the
// axis_size during these conversions. While strictly some ShardingParam
// are invalid because they have more dims than axis, in practice this is not
// a problem because we can still correctly map the shards to the devices.
EXPECT_TRUE(sharding_param.verify().ok());
TF_ASSERT_OK_AND_ASSIGN(auto actual_hlo_sharding,
ToHloSharding(sharding_param));
EXPECT_EQ(param.hlo_sharding, actual_hlo_sharding);
Expand Down

0 comments on commit eac67fd

Please sign in to comment.