Add DefaultValuedOptionalAttr and use_global_device_ids #272

Merged 2 commits on Oct 7, 2022
30 changes: 16 additions & 14 deletions stablehlo/dialect/StablehloOps.td
@@ -1016,7 +1016,7 @@ def StableHLO_SendOp : StableHLO_Op<"send", []> {
     Variadic<HLO_Tensor>:$inputs,
     HLO_Token:$token,
     StableHLO_ChannelHandle:$channel_handle,
-    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$is_host_transfer
   );

   let results = (outs HLO_Token);
@@ -1040,7 +1040,7 @@ def StableHLO_RecvOp : StableHLO_Op<"recv", []> {
   let arguments = (ins
     HLO_Token:$token,
     StableHLO_ChannelHandle:$channel_handle,
-    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$is_host_transfer
   );

   let results = (outs Variadic<HLO_TensorOrToken>);
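The practical effect of the DefaultValuedOptionalAttr switch is that the attribute no longer needs to appear in the textual IR at all; when absent, the generated accessor reports the declared default. A minimal sketch in generic MLIR syntax (hypothetical operands and channel handle, not part of this diff):

// is_host_transfer spelled out explicitly:
%t1 = "stablehlo.send"(%data, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<4xf32>, !stablehlo.token) -> !stablehlo.token

// Omitted entirely; behaves as is_host_transfer = false, the declared default:
%t2 = "stablehlo.send"(%data, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>
} : (tensor<4xf32>, !stablehlo.token) -> !stablehlo.token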
@@ -1242,7 +1242,8 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather", [SameOperandsAndResultEle
     HLO_Tensor:$operand,
     I64Attr:$all_gather_dim,
     I64ElementsAttr:$replica_groups,
-    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle
+    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle,
+    UnitAttr:$use_global_device_ids
Contributor:
I've been thinking about the logistics of adding use_global_device_ids, given that: 1) we promised compatibility guarantees for StableHLO in #1, 2) we have a pending proposal in #115 to define the exact extent of these guarantees, 3) we have another pending proposal in #196 to define the StableHLO evolution process, 4) we don't yet have a spec for this op.

Given that this is a non-controversial backward-compatible change, and that at the moment we don't have policies that govern opset changes, I'm inclined to approve it. Let me just request another review from @GleasonK - our compatibility expert - and if he signs off, let's merge.

The "non-controversial" part is a judgement call, given that this change is synchronized with MHLO and isn't tied to functionality private to XLA (this functionality is used by JAX). Another example of a change that seems similarly non-controversial is #235. In the future, we'll have clear policies which significantly reduce the role of judgement calls for opset changes, but at the moment we're playing it by ear.

Member:
> Non-controversial and backward compatible

Agree. These changes are also (somewhat) forward compatible, since default-valued attributes do not need to be present in the input IR.

The exception is use_global_device_ids: if an op uses it, we should warn that the op may not be forward compatible, since this is a new feature, and I'm guessing that ignoring the value in a previous version could lead to some semantic differences. The machinery for this warning is not in place yet, but should be soon. If ignoring the attr would cause no semantic difference, then it's probably OK to approve. Interested in your thoughts @burmako.

Contributor:
"I'm guessing ignoring the value in a previous version could lead to some semantic differences". I agree that it will lead to semantic differences. The old consumer will likely ignore the (from its perspective) unregistered use_global_device_ids attribute, which will result in a semantic difference.

On the other hand, the only piece of documentation for StableHLO compatibility guarantees is the "Backward compatible ML compute opset inspired by HLO/MHLO" tagline on our homepage. #1 also talks about backward compatibility only. #115 aims to provide stronger guarantees, but it's still under review.

Moreover, the work of migrating MHLO users to StableHLO is still ongoing, so I don't think we have anyone at the moment who can rely on forward compatibility of StableHLO in the first place.

Given that, I think that both de jure and de facto we have good grounds for approving this change, and that would be my recommendation.

   );
   let results = (outs HLO_Tensor);
   let hasVerifier = 1;
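To make the semantic-difference concern in the thread above concrete: when use_global_device_ids is set, replica_groups is read in terms of flattened global device IDs spanning replicas and partitions, rather than replica IDs (mirroring the MHLO/XLA attribute this change is synchronized with), so an old consumer that drops the unregistered attribute would silently regroup the devices. A hedged sketch with hypothetical shapes and groups:

// Assumed setup: 2 replicas x 2 partitions = 4 global devices.
// With use_global_device_ids, [[0, 1], [2, 3]] groups global device IDs
// 0..3; without it, the same numbers would be read as replica IDs,
// which is a different grouping of the same hardware.
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 0 : i64,
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
  replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>,
  use_global_device_ids
} : (tensor<8x2xf32>) -> tensor<16x2xf32>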
@@ -1287,7 +1288,8 @@ def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter",
     HLO_Tensor:$operand,
     I64Attr:$scatter_dimension,
     I64ElementsAttr:$replica_groups,
-    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle
+    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle,
+    UnitAttr:$use_global_device_ids
   );
   let regions = (region SizedRegion<1>:$computation);
   let results = (outs HLO_Tensor);
@@ -1746,7 +1748,7 @@ def StableHLO_CholeskyOp : StableHLO_Op<"cholesky",
   }];
   let arguments = (ins
     HLO_FpOrComplexTensor:$a,
-    DefaultValuedAttr<BoolAttr, "false">:$lower
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$lower
   );

   let results = (outs HLO_FpOrComplexTensor:$result);
@@ -1957,15 +1959,15 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
   let arguments = (ins
     Variadic<HLO_TensorOrTokenOrTuple>:$inputs,
     StrAttr:$call_target_name,
-    DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$has_side_effect,
     DefaultValuedStrAttr<StrAttr, "">:$backend_config,
     // TODO(b/189822916): Remove this field when all clients are migrated to
     // the status-returning API.
-    DefaultValuedAttr<
+    DefaultValuedOptionalAttr<
       StableHLO_CustomCallApiVersionAttr,
       "::mlir::stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL">:
         $api_version,
-    DefaultValuedAttr<StableHLO_FlatSymbolRefArrayAttr, "{}">:$called_computations,
+    DefaultValuedOptionalAttr<StableHLO_FlatSymbolRefArrayAttr, "{}">:$called_computations,
     OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$operand_layouts,
     OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$result_layouts
   );
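For custom_call this means has_side_effect, api_version, and called_computations may all be left out of the attribute dictionary when they hold their defaults; the generated accessors still report false, API_VERSION_ORIGINAL, and an empty array respectively. A sketch (hypothetical operand and result types):

// Only non-default attributes are spelled out; the omitted ones are
// still observable through their accessors at the declared defaults.
%out = "stablehlo.custom_call"(%in) {
  call_target_name = "foo",
  backend_config = "bar"
} : (tensor<3x4xf32>) -> tensor<3x4xf32>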
@@ -2104,7 +2106,7 @@ def StableHLO_GatherOp: StableHLO_Op<"gather", [InferTensorTypeWithReify, NoSide
     HLO_IntTensor:$start_indices,
     StableHLO_GatherDimensionNumbers:$dimension_numbers,
     I64ElementsAttr:$slice_sizes,
-    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted
   );

   let results = (outs HLO_Tensor);
@@ -2227,8 +2229,8 @@ def StableHLO_ScatterOp: StableHLO_Op<"scatter", [SameVariadicOperandSize, Recur
     TensorOf<[AnyInteger, Index]>:$scatter_indices,
     Variadic<HLO_Tensor>:$updates,
     StableHLO_ScatterDimensionNumbers:$scatter_dimension_numbers,
-    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
-    DefaultValuedAttr<BoolAttr, "false">:$unique_indices
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted,
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$unique_indices
   );

   let regions = (region SizedRegion<1>:$update_computation);
@@ -2343,8 +2345,8 @@ def StableHLO_SortOp : StableHLO_Op<"sort",
   }];
   let arguments = (ins
     Variadic<HLO_Tensor>:$inputs,
-    DefaultValuedAttr<I64Attr, "-1">:$dimension,
-    DefaultValuedAttr<BoolAttr, "false">:$is_stable
+    DefaultValuedOptionalAttr<I64Attr, "-1">:$dimension,
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$is_stable
   );

   let results = (outs Variadic<HLO_Tensor>);
@@ -2845,7 +2847,7 @@ def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather",
     HLO_IntTensor:$start_indices,
     HLO_IntTensor:$slice_sizes,
     StableHLO_GatherDimensionNumbers:$dimension_numbers,
-    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted
   );
   let results = (outs HLO_Tensor);
 }
6 changes: 4 additions & 2 deletions stablehlo/tests/ops_stablehlo.mlir
@@ -21,7 +21,8 @@ func.func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
     %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
     "stablehlo.return"(%1) : (tensor<f32>) -> ()
   }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
-      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32>
+      scatter_dimension = 1 : i64,
+      use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf32>
   func.return %0 : tensor<4x4xf32>
 }

@@ -329,7 +330,8 @@ func.func @allgather_dynamic_gather_dim(%arg0: tensor<128x32xf32>) -> tensor<128
   %0 = "stablehlo.all_gather"(%arg0) {
     all_gather_dim = 1 : i64,
     channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
-    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    use_global_device_ids
   } : (tensor<128x32xf32>) -> tensor<128x?xf32>
   func.return %0 : tensor<128x?xf32>
 }
2 changes: 1 addition & 1 deletion stablehlo/tests/print_stablehlo.mlir
@@ -154,7 +154,7 @@ func.func @no_attr_ops(%arg0 : tensor<4xf32>, %arg1 : !stablehlo.token,
 // CHECK-LABEL: func @multiple_attr_ops
 func.func @multiple_attr_ops(%arg0 : tensor<3x4xf32>) -> () {
   // CHECK: %0 = stablehlo.reduce_precision %arg0, format = e8m10 : tensor<3x4xf32>
-  // CHECK-NEXT: %1 = stablehlo.custom_call "foo"(%arg0, %arg0) {api_version = 1 : i32, backend_config = "bar", called_computations = [], has_side_effect = true} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<1x2x3xf32>
+  // CHECK-NEXT: %1 = stablehlo.custom_call "foo"(%arg0, %arg0) {backend_config = "bar", has_side_effect = true} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<1x2x3xf32>
   %0 = "stablehlo.reduce_precision"(%arg0) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32>
   %1 = "stablehlo.custom_call"(%arg0, %arg0) {backend_config = "bar", call_target_name = "foo", has_side_effect = true} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<1x2x3xf32>
   "stablehlo.return"() : () -> ()
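The updated CHECK line reflects the storage-level consequence of the switch: with DefaultValuedOptionalAttr, a default is (as far as I understand the intent) no longer materialized into the attribute dictionary when the input omits it, so an op parsed without api_version and called_computations also prints without them:

// Input (generic syntax, defaults omitted):
//   {backend_config = "bar", call_target_name = "foo", has_side_effect = true}
// Printed before this PR (defaults materialized into the dictionary):
//   {api_version = 1 : i32, backend_config = "bar", called_computations = [], has_side_effect = true}
// Printed after this PR (absent attributes stay absent):
//   {backend_config = "bar", has_side_effect = true}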