
Commit 95311b2
Improve shape refinement to not require inlining.
PiperOrigin-RevId: 567326536
gnecula authored and TensorFlow MLIR Team committed Sep 21, 2023
1 parent f0faec4 commit 95311b2
Showing 2 changed files with 128 additions and 820 deletions.
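A reader's summary of the change, inferred from the tests below rather than taken from the commit message: the shape refinement pass can now refine across func.call boundaries by specializing the callee's signature in place, treating leading scalar i32 operands that fold to constants as dimension arguments and materializing them inside the callee, where previously calls had to be inlined before refinement could proceed. A minimal before/after sketch, with illustrative names that do not appear in this commit:

// Before refinement: dynamic result types flow through the call.
func.func @main(%arg0: tensor<4xf32>) -> tensor<?xf32> {
  %size = stablehlo.constant dense<4> : tensor<i32>
  %0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<?xf32>
  %1 = call @callee(%size, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
  return %1 : tensor<?xf32>
}
func.func @callee(%size: tensor<i32>, %x: tensor<?xf32>) -> tensor<?xf32> {
  return %x : tensor<?xf32>
}

// After refinement (roughly; exact output depends on the pass): the
// dimension argument is dropped from @callee and all types become static.
func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<4xf32>
  %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}
func.func @callee(%x: tensor<4xf32>) -> tensor<4xf32> {
  return %x : tensor<4xf32>
}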
314 changes: 9 additions & 305 deletions stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
@@ -31,19 +31,24 @@ func.func private @helper(%arg0: tensor<f32>) -> tensor<f32> {

// -----
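// Refinement starts from @main when the module defines one: the bitcast in
// @main is refined to a static shape, while @helper, which is never called
// from @main, keeps its unranked result type.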

// CHECK-LABEL: module @has_main
module @has_main {
// CHECK: main
func.func @main(%arg0: tensor<4xf32>) -> tensor<*xi32> {
// CHECK: stablehlo.bitcast_convert{{.*}} -> tensor<4xi32>
%0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<*xi32>
func.return %0 : tensor<*xi32>
}

// CHECK: helper
func.func @helper(%arg0: tensor<4xf32>) -> tensor<*xi32> {
// CHECK: stablehlo.bitcast_convert{{.*}} -> tensor<*xi32>
%0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<*xi32>
func.return %0 : tensor<*xi32>
}
}

// -----

// CHECK-LABEL: func @error_unsupported_operation
func.func @error_unsupported_operation(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> index {
// CHECK: stablehlo.add{{.*}} -> tensor<?xf32>
%0 = stablehlo.add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<?xf32>
@@ -467,316 +472,15 @@ func.func @refine_bitcast_convert_different_bitwidths(%arg0 : tensor<4xf32>) ->

// -----

-// CHECK-LABEL: func @refine_bitcast_convert_same_bitwidth_unranked_result
-func.func @refine_bitcast_convert_same_bitwidth_unranked_result(%arg0 : tensor<4xf32>) -> tensor<*xi32> {
+// CHECK-LABEL: func @refine_bitcast_convert_same_bitwidth
+func.func @refine_bitcast_convert_same_bitwidth(%arg0 : tensor<4xf32>) -> tensor<*xi32> {
// CHECK: stablehlo.bitcast_convert{{.*}} -> tensor<4xi32>
%0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<*xi32>
func.return %0 : tensor<*xi32>
}

// -----

// CHECK-LABEL: func @refine_bitcast_convert_same_bitwidth
func.func @refine_bitcast_convert_same_bitwidth() -> tensor<?x?x0xf32> {
%0 = stablehlo.constant dense<[3, 5, 0]> : tensor<3xi32>
%21 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<3xi32>) -> tensor<?x?x0xui32>
// CHECK: stablehlo.bitcast_convert{{.*}} -> tensor<3x5x0xf32>
%48 = stablehlo.bitcast_convert %21 : (tensor<?x?x0xui32>) -> tensor<?x?x0xf32>
return %48 : tensor<?x?x0xf32>
}

// -----
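// A call whose leading scalar i32 operand folds to a constant treats that
// operand as a dimension argument: the constant is materialized inside the
// callee and dropped from the refined signature checked below.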

// CHECK-LABEL: module @refine_call
module @refine_call {
func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
%1 = stablehlo.constant dense<4> : tensor<i32>
// CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
%2 = call @refine_call_callee(%1, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %2 : tensor<?xf32>
}
// CHECK: refine_call_callee(%arg0: tensor<4xf32>) -> tensor<4xf32>
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: stablehlo.constant dense<4>
%0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
return %1 : tensor<?xf32>
}
}

// -----
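// Multiple leading scalar i32 constants are all treated as dimension
// arguments; both fold into the callee, where 3 + 3 folds to the constant 6.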

// CHECK-LABEL: module @refine_call_dimension_arguments
module @refine_call_dimension_arguments {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT:%.*]] = call @callee
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<3> : tensor<i32>
%1 = call @callee(%0, %0, %arg0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// %arg0 and %arg1 are dimension arguments
// CHECK: @callee([[ARG0:%.*]]: tensor<i32>) -> tensor<i32>
func.func private @callee(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
// CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
// CHECK: return [[RESULT1]]
%0 = stablehlo.add %arg0, %arg1: tensor<i32>
%1 = stablehlo.add %0, %arg2: tensor<i32>
return %1 : tensor<i32>
}
}

// -----
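// A token argument ahead of the dimension arguments is kept in the refined
// signature; only the dimension arguments are dropped.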

// CHECK-LABEL: module @refine_call_prefix_token_and_dimension_arguments
module @refine_call_prefix_token_and_dimension_arguments {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT:%.*]] = call @callee
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<3> : tensor<i32>
%token = stablehlo.create_token : !stablehlo.token
%1 = call @callee(%token, %0, %0, %arg0) : (!stablehlo.token, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// %arg0 and %arg1 are dimension arguments
// CHECK: @callee([[ARG_TOKEN:%.*]]: !stablehlo.token, [[ARG0:%.*]]: tensor<i32>
func.func private @callee(%arg_token: !stablehlo.token, %arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
// CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
// CHECK: return [[RESULT1]]
%0 = stablehlo.add %arg0, %arg1: tensor<i32>
%1 = stablehlo.add %0, %arg2: tensor<i32>
return %1 : tensor<i32>
}
}

// -----
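// Dimension arguments are also recognized when a token follows them; the
// refined callee signature matches the previous test.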

// CHECK-LABEL: module @refine_call_dimension_arguments_followed_by_token
module @refine_call_dimension_arguments_followed_by_token {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT:%.*]] = call @callee
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<3> : tensor<i32>
%token = stablehlo.create_token : !stablehlo.token
%1 = call @callee(%0, %0, %token, %arg0) : (tensor<i32>, tensor<i32>, !stablehlo.token, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// %arg0 and %arg1 are dimension arguments
// CHECK: @callee([[ARG_TOKEN:%.*]]: !stablehlo.token, [[ARG0:%.*]]: tensor<i32>
func.func private @callee(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg_token: !stablehlo.token, %arg2: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
// CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
// CHECK: return [[RESULT1]]
%0 = stablehlo.add %arg0, %arg1: tensor<i32>
%1 = stablehlo.add %0, %arg2: tensor<i32>
return %1 : tensor<i32>
}
}

// -----
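// Calls that produce the same refinement context share a single refined
// callee instead of one specialization per call site.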

// CHECK-LABEL: module @refine_multiple_call_with_same_context
module @refine_multiple_call_with_same_context {
func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
%arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
// CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
%1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
%2 = call @refine_call_callee(%arg0_new, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %2 : tensor<?xf32>
}
// CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
return %arg1 : tensor<?xf32>
}
}

// -----
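// Refined callees fold like any other code: each call folds to the
// constant 8, and @main folds the sum to 16.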

// CHECK-LABEL: module @refine_multiple_call_constant_function
module @refine_multiple_call_constant_function {
func.func @main(%arg0: tensor<5xf32>) -> tensor<i32> {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<16>
// CHECK: return [[RESULT0]]
%0 = stablehlo.constant dense<4> : tensor<i32>
%1 = call @refine_call_callee(%0, %arg0) : (tensor<i32>, tensor<5xf32>) -> tensor<i32>
%2 = call @refine_call_callee(%0, %arg0) : (tensor<i32>, tensor<5xf32>) -> tensor<i32>
%3 = stablehlo.add %1, %2: tensor<i32>
return %3 : tensor<i32>
}
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> tensor<i32> {
// CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<8>
// CHECK: return [[RESULT1]]
%0 = stablehlo.add %arg0, %arg0: tensor<i32>
return %0 : tensor<i32>
}
}

// -----
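// Error case: the two call sites disagree on whether the leading operand is
// a constant, so they would refine the callee under contexts with different
// numbers of non-dimension arguments.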

module @refine_call_multiple_with_different_number_dimension_arguments {
func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
%arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
%1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
// Ensure that the first argument is not a constant at the second call site
%arg0_different_f32 = stablehlo.bitcast_convert %arg0_new : (tensor<i32>) -> tensor<f32>
%arg0_different_i32 = stablehlo.bitcast_convert %arg0_different_f32 : (tensor<f32>) -> tensor<i32>
// expected-error@+1{{incorrect number of operands for callee}}
%2 = call @refine_call_callee(%arg0_different_i32, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %2 : tensor<?xf32>
}
// expected-error@+1{{Function refine_call_callee has already been refined with a different refinement context. Previous context had 1 and now we have 2 non-dimension arguments}}
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
return %arg1 : tensor<?xf32>
}
}

// -----
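// Error case: both call sites pass a constant-foldable leading operand, but
// with different values, so the refinement contexts conflict.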

module @refine_call_multiple_different_dimension_arguments {
func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
%arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
%1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
%arg0_different = stablehlo.add %arg0_new, %arg0_new : tensor<i32>
// expected-error@+1{{incorrect number of operands for callee}}
%2 = call @refine_call_callee(%arg0_different, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %2 : tensor<?xf32>
}
// expected-error@+1{{Function refine_call_callee has already been refined with a different refinement context.}}
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
return %arg1 : tensor<?xf32>
}
}

// -----
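// Error case: the tensor operand refines to different static shapes at the
// two call sites (4 elements vs. 6 after the concatenate), so the refinement
// contexts conflict.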

module @refine_call_multiple_different_non_dimension_arguments {
func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
%arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
%1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
%2 = stablehlo.constant dense<[1., 2.]> : tensor<2xf32>
%3 = stablehlo.concatenate %1, %2, dim = 0 : (tensor<?xf32>, tensor<2xf32>) -> tensor<?xf32>
// expected-error@+1{{incorrect number of operands for callee}}
%4 = call @refine_call_callee(%arg0_new, %3) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %4 : tensor<?xf32>
}
// expected-error@+1{{Function refine_call_callee has already been refined with a different refinement context.}}
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
return %arg1 : tensor<?xf32>
}
}

// -----
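// Error case: recursive calls cannot be refined.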

module @refine_call_recursive {
func.func @main() -> tensor<i32> {
%0 = stablehlo.constant dense<3> : tensor<i32>
%1 = call @refine_call_callee(%0) : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// expected-error@+1{{Function refine_call_callee is being refined recursively}}
func.func @refine_call_callee(%arg0: tensor<i32>) -> tensor<i32> {
// expected-error@+1{{incorrect number of operands}}
%0 = call @refine_call_callee(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
}

// -----
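// Error case: @main's arguments must already have static shapes; unranked
// (here) and dynamically shaped (next test) arguments are rejected.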

module @refine_call_main_argument_unranked {
// expected-error@+1{{main must be refined with static shape arguments}}
func.func public @main(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%2 = call @callee(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
return %2 : tensor<*xi32>
}
func.func private @callee(%arg0: tensor<*xi32>) -> tensor<*xi32> {
return %arg0 : tensor<*xi32>
}
}

// -----

module @refine_call_main_argument_dynamic_shape {
// expected-error@+1{{main must be refined with static shape arguments}}
func.func public @main(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%2 = call @callee(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
func.func private @callee(%arg0: tensor<?xi32>) -> tensor<?xi32> {
return %arg0 : tensor<?xi32>
}
}

// -----
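// Error case: a callee's arguments must refine to static shapes; the
// dynamic_iota result cannot be refined here because its shape operand is
// not a constant.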

module @refine_call_callee_argument_unranked {
func.func public @main(%arg0: tensor<1xi64>) -> tensor<*xi32> {
%1 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<1xi64>) -> tensor<*xi32>
%2 = call @callee(%1) : (tensor<*xi32>) -> tensor<*xi32>
return %2 : tensor<*xi32>
}
// expected-error@+1{{callee must be refined with static shape arguments}}
func.func private @callee(%arg0: tensor<*xi32>) -> tensor<*xi32> {
return %arg0 : tensor<*xi32>
}
}

// -----

module @refine_call_callee_argument_dynamic_shape {
func.func public @main(%arg0: tensor<1xi64>) -> tensor<?xi32> {
%1 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<1xi64>) -> tensor<?xi32>
%2 = call @callee(%1) : (tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
// expected-error@+1{{callee must be refined with static shape arguments}}
func.func private @callee(%arg0: tensor<?xi32>) -> tensor<?xi32> {
return %arg0 : tensor<?xi32>
}
}

// -----

// CHECK-LABEL: module @refine_call_dimension_argument_non_scalar
// The non-scalar constant is not folded into the callee
module @refine_call_dimension_argument_non_scalar {
func.func public @main() -> tensor<4xi32> {
// CHECK: dense<[1, 2, 3, 4]> : tensor<4xi32>
%0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
%1 = call @callee(%0) : (tensor<4xi32>) -> tensor<4xi32>
return %1 : tensor<4xi32>
}
func.func private @callee(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK: return %arg0 : tensor<4xi32>
return %arg0 : tensor<4xi32>
}
}

// -----
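// A scalar constant that is not an integer is not a dimension argument; the
// call and the callee are left unchanged.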

// CHECK-LABEL: module @refine_call_dimension_argument_not_integer
module @refine_call_dimension_argument_not_integer {
func.func public @main() -> tensor<f32> {
%0 = stablehlo.constant dense<3.> : tensor<f32>
// CHECK: call @callee({{.*}}) : (tensor<f32>) -> tensor<f32>
%2 = call @callee(%0) : (tensor<f32>) -> tensor<f32>
return %2 : tensor<f32>
}
func.func private @callee(%arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}
}

// -----

// TODO(#1037): Switch to *xi32 once fixed.
// CHECK-LABEL: func @refine_convert
func.func @refine_convert(%arg0 : tensor<4xf32>) -> tensor<?xi32> {
(The diff for the second changed file, covering the remaining 119 additions and 515 deletions, did not render on this page.)
